{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as pl\n",
    "import numpy as np\n",
    "import torch\n",
    "from scipy.spatial.transform import Rotation\n",
    "import plotly.graph_objects as go\n",
    "import plotly.subplots as sp\n",
    "import trimesh\n",
    "import hydra\n",
    "from omegaconf import OmegaConf, DictConfig\n",
    "import os\n",
    "import time\n",
    "import copy\n",
    "import shutil\n",
    "import open3d as o3d\n",
    "\n",
    "import rootutils\n",
    "rootutils.setup_root(\"../fast3r\", indicator=\".project-root\", pythonpath=True)\n",
    "\n",
    "from fast3r.models.multiview_dust3r_module import MultiViewDUSt3RLitModule\n",
    "from fast3r.dust3r.inference_multiview import inference\n",
    "from fast3r.dust3r.model import FlashDUSt3R\n",
    "from fast3r.dust3r.utils.image import load_images, rgb\n",
    "from fast3r.dust3r.viz import CAM_COLORS, OPENGL, add_scene_cam, cat_meshes, pts3d_to_trimesh\n",
    "\n",
    "\n",
    "pl.ion()\n",
    "\n",
    "\n",
    "def get_reconstructed_scene(\n",
    "    outdir,\n",
    "    model,\n",
    "    device,\n",
    "    silent,\n",
    "    image_size,\n",
    "    filelist,\n",
    "    profiling=False,\n",
    "    dtype=torch.float32,\n",
    "    rotate_clockwise_90=False,\n",
    "    crop_to_landscape=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Load `filelist` as a single multi-view sample and run multi-view inference on it.\n",
    "\n",
    "    Despite the name, no global alignment or scene/mesh export happens here: the raw\n",
    "    output of `inference` is returned, and `outdir` is currently unused (kept only so\n",
    "    existing callers keep working).\n",
    "\n",
    "    Args:\n",
    "        outdir: unused.\n",
    "        model: network forwarded to `inference`.\n",
    "        device: device forwarded to `inference`.\n",
    "        silent (bool): if True, disables verbose output of `load_images` and `inference`.\n",
    "        image_size: target size forwarded to `load_images`.\n",
    "        filelist (list[str]): paths of the images that form the multi-view sample.\n",
    "        profiling (bool): forwarded to `inference`.\n",
    "        dtype (torch.dtype): forwarded to `inference`.\n",
    "        rotate_clockwise_90 (bool): forwarded to `load_images`.\n",
    "        crop_to_landscape (bool): forwarded to `load_images`.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by `inference` for the loaded views.\n",
    "    \"\"\"\n",
    "    multiple_views_in_one_sample = load_images(\n",
    "        filelist,\n",
    "        size=image_size,\n",
    "        verbose=not silent,\n",
    "        rotate_clockwise_90=rotate_clockwise_90,\n",
    "        crop_to_landscape=crop_to_landscape,\n",
    "    )\n",
    "\n",
    "    # Time only the inference call. perf_counter is monotonic and meant for measuring\n",
    "    # elapsed intervals, unlike time.time(), which can jump if the wall clock is adjusted.\n",
    "    start = time.perf_counter()\n",
    "    output = inference(multiple_views_in_one_sample, model, device, dtype=dtype, verbose=not silent, profiling=profiling)\n",
    "    end = time.perf_counter()\n",
    "    print(f\"Time elapsed: {end - start}\")\n",
    "\n",
    "    return output\n",
    "\n",
    "\n",
    "\n",
    "def plot_rgb_images(views, title=\"RGB Images\", save_image_to_folder=None, show=False):\n",
    "    \"\"\"\n",
    "    Build a one-row plotly figure with the input image of every view, optionally saving each as PNG.\n",
    "\n",
    "    Args:\n",
    "        views (list[dict]): per-view inputs; each 'img' tensor is squeezed and transposed from\n",
    "            (3, H, W) to (H, W, 3). Values are assumed normalized to [-1, 1].\n",
    "        title (str): figure title.\n",
    "        save_image_to_folder (str | None): if given, each image is written to\n",
    "            `<folder>/view_<i>.png`; the folder is created if it does not exist.\n",
    "        show (bool): call `fig.show()` at the end. Defaults to False, in which case the\n",
    "            figure is built but not displayed.\n",
    "    \"\"\"\n",
    "    fig = sp.make_subplots(rows=1, cols=len(views), subplot_titles=[f\"View {i} Image\" for i in range(len(views))])\n",
    "\n",
    "    if save_image_to_folder:\n",
    "        # Create the output folder up front so saving into a new directory does not fail.\n",
    "        os.makedirs(save_image_to_folder, exist_ok=True)\n",
    "\n",
    "    for i, view in enumerate(views):\n",
    "        img_rgb = view['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # (H, W, 3)\n",
    "        # Rescale RGB values from [-1, 1] to [0, 255]\n",
    "        img_rgb = ((img_rgb + 1) * 127.5).astype(int).clip(0, 255)\n",
    "\n",
    "        fig.add_trace(go.Image(z=img_rgb), row=1, col=i+1)\n",
    "\n",
    "        if save_image_to_folder:\n",
    "            img_path = os.path.join(save_image_to_folder, f\"view_{i}.png\")\n",
    "            pl.imsave(img_path, img_rgb.astype(np.uint8))\n",
    "\n",
    "    fig.update_layout(\n",
    "        title=title,\n",
    "        margin=dict(l=0, r=0, b=0, t=40)\n",
    "    )\n",
    "\n",
    "    if show:\n",
    "        fig.show()\n",
    "\n",
    "def plot_confidence_maps(preds, title=\"Confidence Maps\", save_image_to_folder=None, show=False):\n",
    "    \"\"\"\n",
    "    Build a one-row plotly figure with the global confidence map ('conf') of every view.\n",
    "\n",
    "    Args:\n",
    "        preds (list[dict]): per-view predictions holding a 'conf' tensor (expected 2-D after squeeze).\n",
    "        title (str): figure title.\n",
    "        save_image_to_folder (str | None): if given, each map is written to\n",
    "            `<folder>/view_<i>_conf.png` with the 'turbo' colormap; the folder is created if needed.\n",
    "        show (bool): call `fig.show()` at the end. Defaults to False, in which case the\n",
    "            figure is built but not displayed.\n",
    "    \"\"\"\n",
    "    fig = sp.make_subplots(rows=1, cols=len(preds), subplot_titles=[f\"View {i} Confidence\" for i in range(len(preds))])\n",
    "\n",
    "    if save_image_to_folder:\n",
    "        # Create the output folder up front so saving into a new directory does not fail.\n",
    "        os.makedirs(save_image_to_folder, exist_ok=True)\n",
    "\n",
    "    for i, pred in enumerate(preds):\n",
    "        conf = pred['conf'].cpu().numpy().squeeze()\n",
    "        fig.add_trace(go.Heatmap(z=conf, colorscale='turbo', showscale=False), row=1, col=i+1)\n",
    "        # Reverse the y-axis so row 0 of the map is drawn at the top, like an image.\n",
    "        fig.update_yaxes(autorange='reversed', row=1, col=i+1)\n",
    "\n",
    "        if save_image_to_folder:\n",
    "            conf_path = os.path.join(save_image_to_folder, f\"view_{i}_conf.png\")\n",
    "            pl.imsave(conf_path, conf, cmap='turbo')\n",
    "\n",
    "    fig.update_layout(\n",
    "        title=title,\n",
    "        margin=dict(l=0, r=0, b=0, t=40)\n",
    "    )\n",
    "\n",
    "    if show:\n",
    "        fig.show()\n",
    "\n",
    "def maybe_plot_local_depth_and_conf(preds, title=\"Local Depth and Confidence Maps\", save_image_to_folder=None, show=False):\n",
    "    \"\"\"\n",
    "    Plot the local confidence map and local depth map of every view, when those keys are present.\n",
    "\n",
    "    The figure has 2 rows and one column per view: row 1 shows 'conf_local', row 2 shows\n",
    "    the Z channel of 'pts3d_local'. A view missing a key leaves that cell empty.\n",
    "\n",
    "    Args:\n",
    "        preds (list[dict]): per-view predictions; 'conf_local' and 'pts3d_local' are optional.\n",
    "        title (str): figure title.\n",
    "        save_image_to_folder (str | None): if given, writes `view_<i>_conf_local.png` and\n",
    "            `view_<i>_depth_local.png` there; the folder is created if needed.\n",
    "        show (bool): call `fig.show()` at the end. Defaults to False, in which case the\n",
    "            figure is built but not displayed.\n",
    "    \"\"\"\n",
    "    num_views = len(preds)\n",
    "\n",
    "    # make_subplots consumes titles in row-major order: every column of row 1, then row 2.\n",
    "    conf_titles = [f\"View {i+1} Conf\" if 'conf_local' in pred else f\"View {i+1} No Conf\" for i, pred in enumerate(preds)]\n",
    "    depth_titles = [f\"View {i+1} Depth\" if 'pts3d_local' in pred else f\"View {i+1} No Depth\" for i, pred in enumerate(preds)]\n",
    "    fig = sp.make_subplots(rows=2, cols=num_views, subplot_titles=conf_titles + depth_titles)\n",
    "\n",
    "    if save_image_to_folder:\n",
    "        # Create the output folder up front so saving into a new directory does not fail.\n",
    "        os.makedirs(save_image_to_folder, exist_ok=True)\n",
    "\n",
    "    for i, pred in enumerate(preds):\n",
    "        col = i + 1\n",
    "\n",
    "        if 'conf_local' in pred:\n",
    "            conf_local = pred['conf_local'].cpu().numpy().squeeze()\n",
    "            fig.add_trace(go.Heatmap(z=conf_local, colorscale='Turbo', showscale=False), row=1, col=col)\n",
    "            # Address the axis by (row, col): with 2 rows the axis ids are numbered row-major,\n",
    "            # so hand-computed 'yaxis{k}' indices are easy to get wrong.\n",
    "            fig.update_yaxes(autorange='reversed', row=1, col=col)\n",
    "\n",
    "            if save_image_to_folder:\n",
    "                conf_local_path = os.path.join(save_image_to_folder, f\"view_{i}_conf_local.png\")\n",
    "                pl.imsave(conf_local_path, conf_local, cmap='turbo')\n",
    "\n",
    "        if 'pts3d_local' in pred:\n",
    "            # Use the Z coordinate of the local point map as depth.\n",
    "            depth_local = pred['pts3d_local'][..., 2].cpu().numpy().squeeze()\n",
    "            fig.add_trace(go.Heatmap(z=depth_local, colorscale='Greys', showscale=False), row=2, col=col)\n",
    "            fig.update_yaxes(autorange='reversed', row=2, col=col)\n",
    "\n",
    "            if save_image_to_folder:\n",
    "                depth_local_path = os.path.join(save_image_to_folder, f\"view_{i}_depth_local.png\")\n",
    "                pl.imsave(depth_local_path, depth_local, cmap='Greys')\n",
    "\n",
    "    fig.update_layout(\n",
    "        title=title,\n",
    "        margin=dict(l=0, r=0, b=0, t=40)\n",
    "    )\n",
    "\n",
    "    if show:\n",
    "        fig.show()\n",
    "\n",
    "def plot_3d_points_with_colors(preds, views, title=\"3D Points Visualization\", flip_axes=False, as_mesh=False, min_conf_thr_percentile=80, export_ply_path=None):\n",
    "    \"\"\"\n",
    "    Show the per-view global point maps ('pts3d_in_other_view') in one interactive plotly 3D figure.\n",
    "\n",
    "    For every view, only points whose 'conf' value is strictly above that view's\n",
    "    `min_conf_thr_percentile`-th percentile are kept. Colors come from views[i]['img'],\n",
    "    assumed normalized to [-1, 1] and rescaled to [0, 255].\n",
    "\n",
    "    Args:\n",
    "        preds (list[dict]): per-view predictions with 'pts3d_in_other_view' and 'conf' tensors.\n",
    "        views (list[dict]): per-view inputs aligned with `preds`, each with an 'img' tensor.\n",
    "        title (str): figure title.\n",
    "        flip_axes (bool): swap Y/Z and negate the new Z. Applied to both the plot and the export.\n",
    "        as_mesh (bool): triangulate each view with `pts3d_to_trimesh` and plot one combined mesh\n",
    "            instead of a point cloud.\n",
    "        min_conf_thr_percentile (float): per-view confidence percentile used as the threshold.\n",
    "        export_ply_path (str | None): if set, the mesh / point cloud (uint8 RGB colors) is also\n",
    "            exported to this path.\n",
    "    \"\"\"\n",
    "    fig = go.Figure()\n",
    "\n",
    "    def _to_uint8_rgb(values):\n",
    "        # Rescale RGB values from [-1, 1] to [0, 255]. Clip before the uint8 cast so\n",
    "        # out-of-range values saturate instead of wrapping around.\n",
    "        return ((values + 1) * 127.5).clip(0, 255).astype(np.uint8)\n",
    "\n",
    "    if as_mesh:\n",
    "        meshes = []\n",
    "        for i, pred in enumerate(preds):\n",
    "            pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()\n",
    "            img_rgb = _to_uint8_rgb(views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0))\n",
    "            conf = pred['conf'].cpu().numpy().squeeze()\n",
    "\n",
    "            # Keep only points above this view's confidence percentile.\n",
    "            mask = conf > np.percentile(conf, min_conf_thr_percentile)\n",
    "\n",
    "            meshes.append(pts3d_to_trimesh(img_rgb, pts3d, valid=mask))\n",
    "\n",
    "        combined_mesh = trimesh.Trimesh(**cat_meshes(meshes))\n",
    "\n",
    "        if flip_axes:\n",
    "            combined_mesh.vertices[:, [1, 2]] = combined_mesh.vertices[:, [2, 1]]\n",
    "            combined_mesh.vertices[:, 2] = -combined_mesh.vertices[:, 2]\n",
    "\n",
    "        if export_ply_path:\n",
    "            combined_mesh.export(export_ply_path)\n",
    "\n",
    "        # Face color = mean of its three vertex colors, computed for all faces at once.\n",
    "        vertex_colors = combined_mesh.visual.vertex_colors[:, :3]  # drop alpha\n",
    "        face_rgb = vertex_colors[combined_mesh.faces].mean(axis=1).astype(int)\n",
    "        face_colors = ['rgb({}, {}, {})'.format(r, g, b) for r, g, b in face_rgb]\n",
    "\n",
    "        fig.add_trace(go.Mesh3d(\n",
    "            x=combined_mesh.vertices[:, 0],\n",
    "            y=combined_mesh.vertices[:, 1],\n",
    "            z=combined_mesh.vertices[:, 2],\n",
    "            i=combined_mesh.faces[:, 0],\n",
    "            j=combined_mesh.faces[:, 1],\n",
    "            k=combined_mesh.faces[:, 2],\n",
    "            facecolor=face_colors,\n",
    "            opacity=0.5,\n",
    "            name=\"Combined Mesh\"\n",
    "        ))\n",
    "    else:\n",
    "        all_points = []\n",
    "        all_colors = []\n",
    "        for i, pred in enumerate(preds):\n",
    "            points = pred['pts3d_in_other_view'].cpu().numpy().squeeze().reshape(-1, 3)\n",
    "            rgb_values = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0).reshape(-1, 3)\n",
    "            conf = pred['conf'].cpu().numpy().squeeze().reshape(-1)\n",
    "\n",
    "            mask = conf > np.percentile(conf, min_conf_thr_percentile)\n",
    "            points = points[mask]  # boolean indexing copies, so the flip below never touches the tensors\n",
    "            colors = _to_uint8_rgb(rgb_values[mask])\n",
    "\n",
    "            if flip_axes:\n",
    "                points = points[:, [0, 2, 1]]  # swap Y and Z\n",
    "                points[:, 2] = -points[:, 2]   # negate the new Z\n",
    "\n",
    "            # Collect exactly what is plotted (flipped points, uint8 colors) for the optional export.\n",
    "            all_points.append(points)\n",
    "            all_colors.append(colors)\n",
    "\n",
    "            fig.add_trace(go.Scatter3d(\n",
    "                x=points[:, 0], y=points[:, 1], z=points[:, 2],\n",
    "                mode='markers',\n",
    "                marker=dict(size=2, opacity=0.8, color=['rgb({}, {}, {})'.format(r, g, b) for r, g, b in colors]),\n",
    "                name=f\"View {i}\"\n",
    "            ))\n",
    "\n",
    "        if export_ply_path:\n",
    "            point_cloud = trimesh.PointCloud(vertices=np.vstack(all_points), colors=np.vstack(all_colors))\n",
    "            point_cloud.export(export_ply_path)\n",
    "\n",
    "    fig.update_layout(\n",
    "        title=title,\n",
    "        scene=dict(\n",
    "            xaxis_title='X',\n",
    "            yaxis_title='Y',\n",
    "            zaxis_title='Z'\n",
    "        ),\n",
    "        margin=dict(l=0, r=0, b=0, t=40),\n",
    "        height=1000\n",
    "    )\n",
    "\n",
    "    fig.show()\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import plotly.graph_objects as go\n",
    "from fast3r.dust3r.cloud_opt.init_im_poses import fast_pnp\n",
    "from fast3r.dust3r.viz import auto_cam_size\n",
    "from fast3r.dust3r.viz_plotly import SceneViz\n",
    "from fast3r.dust3r.utils.image import rgb  # Assuming you have this utility for image processing\n",
    "\n",
    "\n",
    "# Function to visualize 3D points and camera poses with SceneViz\n",
    "def plot_3d_points_with_estimated_camera_poses(preds, views, title=\"3D Points and Camera Poses\", flip_axes=False, min_conf_thr_percentile=80, export_ply_path=None, export_html_path=None, niter_PnP=10):\n",
    "    \"\"\"\n",
    "    Visualize the global point maps together with camera poses estimated from them, using SceneViz.\n",
    "\n",
    "    Args:\n",
    "        preds (list[dict]): per-view predictions with 'pts3d_in_other_view' and 'conf' tensors.\n",
    "        views (list[dict]): per-view inputs aligned with `preds`; 'img' is assumed normalized to [-1, 1].\n",
    "        title (str): currently unused (kept for API compatibility).\n",
    "        flip_axes (bool): swap Y/Z and negate the new Z of the point maps before pose estimation.\n",
    "            Works on a deep copy, so the caller's `preds` are left untouched.\n",
    "        min_conf_thr_percentile (float): per-view confidence percentile; only points strictly\n",
    "            above it are shown and exported.\n",
    "        export_ply_path (str | None): if set, write the filtered points with uint8 RGB colors here.\n",
    "        export_html_path (str | None): if set, save the SceneViz figure as HTML here.\n",
    "        niter_PnP (int): forwarded to `MultiViewDUSt3RLitModule.estimate_camera_poses`.\n",
    "    \"\"\"\n",
    "    viz = SceneViz()\n",
    "\n",
    "    if flip_axes:\n",
    "        preds = copy.deepcopy(preds)\n",
    "        for pred in preds:\n",
    "            # Advanced indexing returns a new tensor, so the in-place negation is on the copy.\n",
    "            pts3d = pred['pts3d_in_other_view'][..., [0, 2, 1]]  # Swap Y and Z axes\n",
    "            pts3d[..., 2] *= -1  # Flip the sign of the Z axis\n",
    "            pred['pts3d_in_other_view'] = pts3d\n",
    "\n",
    "    poses_c2w, estimated_focals = MultiViewDUSt3RLitModule.estimate_camera_poses(preds, niter_PnP=niter_PnP)\n",
    "    poses_c2w = poses_c2w[0]  # batch size is 1\n",
    "    estimated_focals = estimated_focals[0]  # batch size is 1\n",
    "    # Camera glyph size derived from the poses, never smaller than 0.05.\n",
    "    cam_size = max(auto_cam_size(poses_c2w), 0.05)\n",
    "\n",
    "    for i, (pred, pose_c2w) in enumerate(zip(preds, poses_c2w)):\n",
    "        pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()\n",
    "        img_rgb = rgb(views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0))\n",
    "        conf = pred['conf'].cpu().numpy().squeeze()\n",
    "\n",
    "        # Keep only points above this view's confidence percentile.\n",
    "        mask = conf > np.percentile(conf, min_conf_thr_percentile)\n",
    "\n",
    "        viz.add_pointcloud(pts3d, img_rgb, mask=mask, point_size=1.0, view_idx=i)\n",
    "        viz.add_camera(\n",
    "            pose_c2w=pose_c2w,\n",
    "            focal=estimated_focals[i],\n",
    "            color=np.random.randint(0, 256, size=3),  # random RGB color per camera\n",
    "            image=img_rgb,\n",
    "            cam_size=cam_size,\n",
    "            view_idx=i\n",
    "        )\n",
    "\n",
    "    if export_ply_path:\n",
    "        all_points = []\n",
    "        all_colors = []\n",
    "        for i, pred in enumerate(preds):\n",
    "            pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()\n",
    "            img_raw = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)\n",
    "            conf = pred['conf'].cpu().numpy().squeeze()\n",
    "            mask = conf > np.percentile(conf, min_conf_thr_percentile)\n",
    "            all_points.append(pts3d[mask])\n",
    "            # Rescale RGB values from [-1, 1] to uint8 [0, 255]; writing the raw normalized\n",
    "            # values would store negative / fractional colors in the .ply.\n",
    "            all_colors.append(((img_raw[mask] + 1) * 127.5).clip(0, 255).astype(np.uint8))\n",
    "\n",
    "        point_cloud = trimesh.PointCloud(vertices=np.vstack(all_points), colors=np.vstack(all_colors))\n",
    "        point_cloud.export(export_ply_path)\n",
    "\n",
    "    if export_html_path:\n",
    "        viz.export_html(export_html_path)\n",
    "\n",
    "    viz.show()\n",
    "\n",
    "def save_pointmaps_and_camera_parameters_to_folder(preds, save_folder, niter_PnP=100, focal_length_estimation_method='individual'):\n",
    "    \"\"\"\n",
    "    Estimate per-view camera parameters and store them, together with the predicted\n",
    "    point/confidence maps, in `<save_folder>/pointmaps_and_camera_params.npz`.\n",
    "\n",
    "    Args:\n",
    "        preds (list): per-view prediction dicts; each must provide the tensors\n",
    "            'pts3d_in_other_view', 'conf', 'pts3d_local', 'pts3d_local_aligned_to_global'\n",
    "            and 'conf_local'.\n",
    "        save_folder (str): output directory; created if it does not exist.\n",
    "        niter_PnP (int): forwarded to `MultiViewDUSt3RLitModule.estimate_camera_poses`.\n",
    "        focal_length_estimation_method (str): forwarded to\n",
    "            `MultiViewDUSt3RLitModule.estimate_camera_poses`.\n",
    "\n",
    "    The archive holds, one entry per view and in this order: global_pointmaps,\n",
    "    global_confidence_maps, local_pointmaps, local_aligned_to_global_pointmaps,\n",
    "    local_confidence_maps, estimated_focals and estimated_poses_c2w.\n",
    "    \"\"\"\n",
    "    batch_poses, batch_focals = MultiViewDUSt3RLitModule.estimate_camera_poses(\n",
    "        preds, niter_PnP=niter_PnP, focal_length_estimation_method=focal_length_estimation_method\n",
    "    )\n",
    "    # Only the first batch element is used (assumes batch size 1).\n",
    "    view_poses = batch_poses[0]\n",
    "    view_focals = batch_focals[0]\n",
    "\n",
    "    def to_numpy(tensor):\n",
    "        return tensor.cpu().numpy().squeeze()\n",
    "\n",
    "    # (npz key, pred key) pairs for the per-view maps, in archive order.\n",
    "    map_fields = (\n",
    "        ('global_pointmaps', 'pts3d_in_other_view'),\n",
    "        ('global_confidence_maps', 'conf'),\n",
    "        ('local_pointmaps', 'pts3d_local'),\n",
    "        ('local_aligned_to_global_pointmaps', 'pts3d_local_aligned_to_global'),\n",
    "        ('local_confidence_maps', 'conf_local'),\n",
    "    )\n",
    "    archive = {npz_key: [] for npz_key, _ in map_fields}\n",
    "    archive['estimated_focals'] = []\n",
    "    archive['estimated_poses_c2w'] = []\n",
    "\n",
    "    for view_idx, pred in enumerate(preds):\n",
    "        for npz_key, pred_key in map_fields:\n",
    "            archive[npz_key].append(to_numpy(pred[pred_key]))\n",
    "\n",
    "        focal = view_focals[view_idx]\n",
    "        archive['estimated_focals'].append(focal.item() if isinstance(focal, torch.Tensor) else focal)\n",
    "\n",
    "        pose = view_poses[view_idx]\n",
    "        archive['estimated_poses_c2w'].append(pose.cpu().numpy() if isinstance(pose, torch.Tensor) else pose)\n",
    "\n",
    "    os.makedirs(save_folder, exist_ok=True)\n",
    "    np.savez(os.path.join(save_folder, 'pointmaps_and_camera_params.npz'), **archive)\n",
    "\n",
    "\n",
    "def export_combined_ply(preds, views, export_ply_path=None,\n",
    "                        pts3d_key_to_visualize=\"pts3d_local_aligned_to_global\",\n",
    "                        conf_key_to_visualize=\"conf_local\",\n",
    "                        min_conf_thr_percentile=0, flip_axes=False, max_num_points=None, sampling_strategy='uniform'):\n",
    "    \"\"\"\n",
    "    Merge the per-view point maps into one colored point cloud, optionally downsample it and write it to .ply.\n",
    "\n",
    "    Args:\n",
    "        preds (list[dict]): per-view predictions holding the `pts3d_key_to_visualize` tensor\n",
    "            (reshaped to (-1, 3)) and the `conf_key_to_visualize` tensor (one value per point).\n",
    "        views (list[dict]): per-view inputs aligned with `preds`; 'img' is squeezed and transposed\n",
    "            from (3, H, W) and assumed normalized to [-1, 1].\n",
    "        export_ply_path (str | None): if set, the resulting cloud is written there via trimesh.\n",
    "        pts3d_key_to_visualize (str): pred key of the point map to export.\n",
    "        conf_key_to_visualize (str): pred key of the confidence map used for filtering.\n",
    "        min_conf_thr_percentile (float): per-view percentile; points with confidence strictly\n",
    "            above it are kept. A value <= 0 keeps every point.\n",
    "        flip_axes (bool): swap Y/Z and negate the new Z.\n",
    "        max_num_points (int | None): upper bound on the number of returned points.\n",
    "        sampling_strategy (str): used only when downsampling is needed: 'uniform' (random subset),\n",
    "            'voxel' (Open3D voxel grid, then a uniform subset if the heuristic still leaves too many\n",
    "            points) or 'farthest_point' (Open3D farthest-point sampling).\n",
    "\n",
    "    Returns:\n",
    "        tuple[np.ndarray, np.ndarray]: (N, 3) points and (N, 3) uint8 RGB colors.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if downsampling is needed and `sampling_strategy` is not supported.\n",
    "    \"\"\"\n",
    "\n",
    "    def _to_o3d_cloud(points, colors):\n",
    "        # Open3D stores colors as floats in [0, 1].\n",
    "        pcd = o3d.geometry.PointCloud()\n",
    "        pcd.points = o3d.utility.Vector3dVector(points)\n",
    "        pcd.colors = o3d.utility.Vector3dVector(colors.astype(np.float64) / 255.0)\n",
    "        return pcd\n",
    "\n",
    "    def _from_o3d_cloud(pcd):\n",
    "        # Back to numpy; round instead of truncating so tiny float error (e.g. 152.9999...)\n",
    "        # does not drop a color level.\n",
    "        points = np.asarray(pcd.points)\n",
    "        colors = np.clip(np.rint(np.asarray(pcd.colors) * 255.0), 0, 255).astype(np.uint8)\n",
    "        return points, colors\n",
    "\n",
    "    def _uniform_subset(points, colors, n):\n",
    "        # Uniform random subset of n points without replacement (NumPy global RNG).\n",
    "        indices = np.random.choice(len(points), size=n, replace=False)\n",
    "        return points[indices], colors[indices]\n",
    "\n",
    "    all_points = []\n",
    "    all_colors = []\n",
    "\n",
    "    for i, pred in enumerate(preds):\n",
    "        points = pred[pts3d_key_to_visualize].cpu().numpy().squeeze().reshape(-1, 3)\n",
    "        rgb_values = views[i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0).reshape(-1, 3)\n",
    "        conf = pred[conf_key_to_visualize].cpu().numpy().squeeze().reshape(-1)\n",
    "\n",
    "        if min_conf_thr_percentile <= 0:\n",
    "            # The 0th percentile is the minimum itself, so a strict '>' would silently drop the\n",
    "            # least-confident points; a non-positive percentile therefore means \"keep everything\".\n",
    "            mask = np.ones(conf.shape, dtype=bool)\n",
    "        else:\n",
    "            mask = conf > np.percentile(conf, min_conf_thr_percentile)\n",
    "\n",
    "        points = points[mask]  # boolean indexing copies, so the in-place flip below is safe\n",
    "        # Rescale RGB values from [-1, 1] to [0, 255]; clip BEFORE the uint8 cast so\n",
    "        # out-of-range values saturate instead of wrapping around.\n",
    "        colors = ((rgb_values[mask] + 1) * 127.5).clip(0, 255).astype(np.uint8)\n",
    "\n",
    "        if flip_axes:\n",
    "            points = points[:, [0, 2, 1]]  # Swap y and z\n",
    "            points[:, 2] = -points[:, 2]   # Invert z-axis\n",
    "\n",
    "        all_points.append(points)\n",
    "        all_colors.append(colors)\n",
    "\n",
    "    if all_points:\n",
    "        all_points = np.vstack(all_points)\n",
    "        all_colors = np.vstack(all_colors)\n",
    "    else:\n",
    "        # No views: return empty (0, 3) arrays instead of letting np.vstack raise.\n",
    "        all_points = np.empty((0, 3), dtype=np.float32)\n",
    "        all_colors = np.empty((0, 3), dtype=np.uint8)\n",
    "\n",
    "    if max_num_points is not None and len(all_points) > max_num_points:\n",
    "        if sampling_strategy == 'uniform':\n",
    "            all_points, all_colors = _uniform_subset(all_points, all_colors, max_num_points)\n",
    "        elif sampling_strategy == 'voxel':\n",
    "            pcd = _to_o3d_cloud(all_points, all_colors)\n",
    "            extent = np.asarray(pcd.get_axis_aligned_bounding_box().get_extent())\n",
    "            # Heuristic: pick the voxel edge so the bounding box holds about max_num_points voxels.\n",
    "            # Only non-degenerate axes are used, so flat or line-like clouds do not yield a zero\n",
    "            # voxel size (which Open3D rejects); identical points collapse with any positive size.\n",
    "            spans = extent[extent > 0]\n",
    "            if len(spans) > 0:\n",
    "                voxel_size = (np.prod(spans) / max_num_points) ** (1.0 / len(spans))\n",
    "            else:\n",
    "                voxel_size = 1.0\n",
    "            all_points, all_colors = _from_o3d_cloud(pcd.voxel_down_sample(voxel_size))\n",
    "            # The voxel heuristic is approximate; enforce the cap if it still left too many points.\n",
    "            if len(all_points) > max_num_points:\n",
    "                all_points, all_colors = _uniform_subset(all_points, all_colors, max_num_points)\n",
    "        elif sampling_strategy == 'farthest_point':\n",
    "            # Note: may be slow for large point clouds.\n",
    "            pcd = _to_o3d_cloud(all_points, all_colors)\n",
    "            all_points, all_colors = _from_o3d_cloud(pcd.farthest_point_down_sample(max_num_points))\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported sampling strategy: {sampling_strategy}\")\n",
    "\n",
    "    if export_ply_path:\n",
    "        point_cloud = trimesh.PointCloud(vertices=all_points, colors=all_colors)\n",
    "        point_cloud.export(export_ply_path)\n",
    "\n",
    "    return all_points, all_colors\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import random\n",
    "\n",
    "data_root = \"../data\"\n",
    "\n",
    "# filelist_train = [\n",
    "#     f\"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000001.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000002.jpg\"\n",
    "# ]\n",
    "\n",
    "# apple\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000200.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000085.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000090.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000170.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/apple/189_20393_38136/images/frame000199.jpg\",\n",
    "# ]\n",
    "\n",
    "\n",
    "# bench test\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000006.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000016.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000026.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000036.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000096.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000126.jpg\",\n",
    "#     # f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000156.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/bench/415_57112_110099/images/frame000186.jpg\",\n",
    "# ]\n",
    "\n",
    "# # teddy bear train\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000001.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000002.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000003.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000004.jpg\",\n",
    "#     # f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000012.jpg\",\n",
    "#     # f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000022.jpg\",\n",
    "#     # f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000032.jpg\",\n",
    "# ]\n",
    "# teddy bear test\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000016.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000026.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000126.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000156.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000186.jpg\",\n",
    "# ]\n",
    "\n",
    "# teddy bear random order\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000126.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000026.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000186.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000016.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000156.jpg\",\n",
    "# ]\n",
    "\n",
    "\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000006.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000036.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000066.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/teddybear/34_1479_4753/images/frame000096.jpg\",\n",
    "# ]\n",
    "\n",
    "# suitcase test\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000006.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000016.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000026.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000036.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000096.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000126.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000156.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/suitcase/50_2928_8645/images/frame000186.jpg\",\n",
    "# ]\n",
    "\n",
    "# cake test\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000006.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000016.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000026.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000036.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000096.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000126.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000156.jpg\",\n",
    "#     f\"{data_root}/co3d_subset_processed/cake/374_42274_84517/images/frame000186.jpg\",\n",
    "# ]\n",
    "\n",
    "# in-the-wild obj: book\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/unseen_book/IMG_9837.jpg\",\n",
    "#     f\"{data_root}/unseen_book/IMG_9838.jpg\",\n",
    "#     f\"{data_root}/unseen_book/IMG_9839.jpg\",\n",
    "#     f\"{data_root}/unseen_book/IMG_9840.jpg\",\n",
    "#     f\"{data_root}/unseen_book/IMG_9841.jpg\",\n",
    "#     f\"{data_root}/unseen_book/IMG_9842.jpg\",\n",
    "#     f\"{data_root}/unseen_book/IMG_9843.jpg\",\n",
    "#     f\"{data_root}/unseen_book/IMG_9844.jpg\",\n",
    "# ]\n",
    "\n",
    "# in-the-wild obj: beef jerky\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/beef_jerky/IMG_0050.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0051.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0052.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0053.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0054.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0055.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0056.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0057.jpg\",\n",
    "#     f\"{data_root}/beef_jerky/IMG_0058.jpg\",\n",
    "# ]\n",
    "\n",
    "\n",
    "# ArkitScenes\n",
    "# filelist_test = [\n",
    "#     f\"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_312.125.png\",\n",
    "#     f\"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_313.124.png\",\n",
    "#     f\"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_314.124.png\",\n",
    "#     f\"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_315.123.png\",\n",
    "#     f\"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_316.123.png\",\n",
    "#     f\"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_317.123.png\",\n",
    "#     f\"/datasets01/ARKitScenes/raw/Validation/41069021/vga_wide/41069021_318.122.png\",\n",
    "# ]\n",
    "\n",
    "# HSSD\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/0_102344022_0/rgb/0000{i:02d}.png\" for i in range(8)\n",
    "# ]\n",
    "\n",
    "# filelist_test = [\n",
    "#     f\"{data_root}/17_102344250_4/rgb/0000{i:02d}.png\" for i in range(0,15)\n",
    "# ]\n",
    "\n",
    "# unseen obj: teddy bear from co3d\n",
    "# filelist_test = [\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000006.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000036.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000056.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000086.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000096.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000126.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000156.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/teddybear/595_90395_180050/images/frame000186.jpg\",\n",
    "# ]\n",
    "\n",
    "\n",
    "# unseen obj: keyboard from co3d\n",
    "# filelist_test = [\n",
    "#     \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000096.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000126.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000156.jpg\",\n",
    "#     # \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000186.jpg\",\n",
    "#     # \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000006.jpg\",\n",
    "#     # \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000016.jpg\",\n",
    "#     # \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000026.jpg\",\n",
    "#     \"/datasets01/co3dv2/080422/keyboard/604_93822_187288/images/frame000036.jpg\",\n",
    "# ]\n",
    "\n",
    "# filelist_test = [\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000006.jpg\",\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000016.jpg\",\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000026.jpg\",\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000036.jpg\",\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000096.jpg\",\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000126.jpg\",\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000156.jpg\",\n",
    "#     \"/path/to/dust3r_data/co3d_50_seqs_per_category_subset_processed/keyboard/76_7733_16196/images/frame000186.jpg\",\n",
    "# ]\n",
    "\n",
    "# DTU\n",
    "# filelist_test = [\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_001_max.png\",\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_002_max.png\",\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_003_max.png\",\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_004_max.png\",\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_005_max.png\",\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_006_max.png\",\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_007_max.png\",\n",
    "#     \"/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Rectified/scan6/rect_008_max.png\",\n",
    "# ]\n",
    "\n",
    "\n",
    "# DTU test\n",
    "# filelist_test = [f\"/path/to/dust3r_data/dtu_test_mvsnet_release/scan4/images/000000{i:02d}.jpg\" for i in range(0, 49, 5)]\n",
    "# filelist_test = [f\"/path/to/dust3r_data/dtu_test_mvsnet_release/scan1/images/000000{i:02d}.jpg\" for i in range(0, 49, 5)]\n",
    "\n",
    "# NRGBD test\n",
    "# filelist_test = [f\"../data/neural_rgbd/kitchen/images/img{i}.png\" for i in range(1, 1517, 50)]\n",
    "# filelist_test = [f\"../data/neural_rgbd/morning_apartment/images/img{i}.png\" for i in range(1, 919, 30)]\n",
    "# filelist_test = [f\"../data/neural_rgbd/whiteroom/images/img{i}.png\" for i in range(1, 1675, 50)]\n",
    "# filelist_test = [f\"../data/neural_rgbd/grey_white_room/images/img{i}.png\" for i in range(1, 1492, 50)]\n",
    "# filelist_test = [f\"../data/neural_rgbd/green_room/images/img{i}.png\" for i in range(1, 1441, 100)]\n",
    "# filelist_test = [f\"../data/neural_rgbd/staircase/images/img{i}.png\" for i in range(0, 1148, 40)]\n",
    "\n",
    "# 7-Scenes test\n",
    "# filelist_test = [f\"/path/to/dust3r_data/7_scenes_processed/redkitchen/seq-06/frame-00{i:04d}.color.png\" for i in range(0, 1000, 50)]\n",
    "# filelist_test = filelist_test[0:10] * 50\n",
    "# filelist_test.pop(2)\n",
    "# filelist_test = [f\"/path/to/dust3r_data/7_scenes_processed/redkitchen/seq-06/frame-00{i:04d}.color.png\" for i in range(0, 1000, 2)][:88]\n",
    "# filelist_test = [f\"/path/to/dust3r_data/7_scenes_processed/redkitchen/seq-03/frame-00{i:04d}.color.png\" for i in range(0, 1000, 2)][:320]\n",
    "# filelist_test = [f\"/path/to/dust3r_data/7_scenes_processed/pumpkin/seq-02/frame-00{i:04d}.color.png\" for i in range(0, 1000, 20)]\n",
    "# filelist_test = [f\"/path/to/dust3r_data/7_scenes_processed/office/seq-09/frame-00{i:04d}.color.png\" for i in range(0, 1000, 20)]\n",
    "# filelist_test = [f\"/path/to/dust3r_data/7_scenes_processed/fire/seq-04/frame-00{i:04d}.color.png\" for i in range(0, 1000, 30)]\n",
    "\n",
    "\n",
    "# Tanks and Temples\n",
    "# use all images from /home/ssax/InstantSplat/data/collated_instantsplat_data/eval/Tanks/Family/24_views/dust3r_9_views/images by walking through the folder\n",
    "# filelist_test = []\n",
    "# for root, dirs, files in os.walk(\"/home/ssax/InstantSplat/data/collated_instantsplat_data/eval/Tanks/Family/24_views/dust3r_9_views/images\"):\n",
    "#     for file in files:\n",
    "#         filelist_test.append(os.path.join(root, file))\n",
    "# filelist_test = sorted(filelist_test) \n",
    "\n",
    "# filelist_test = [f\"/data/jianingy/tanks_and_temples/Barn/{i:06d}.jpg\" for i in range(1, 410, 1)]\n",
    "filelist_test = [f\"/data/jianingy/tanks_and_temples_subset/Barn/{i:06d}.jpg\" for i in range(1, 410, 2)]\n",
    "filelist_test = [f\"/data/jianingy/tanks_and_temples_subset/Lighthouse/{i:05d}.jpg\" for i in range(1, 309, 1)]\n",
    "filelist_test = [f\"/data/jianingy/tanks_and_temples_subset/Playground/{i:05d}.jpg\" for i in range(1, 307, 1)]\n",
    "filelist_test = [f\"/data/jianingy/tanks_and_temples_subset/Family/{i:05d}.jpg\" for i in range(1, 152, 50)]\n",
    "# filelist_test = [f\"/data/jianingy/tanks_and_temples/Courthouse/images/{i:08d}.jpg\" for i in range(1, 500, 1)]\n",
    "# filelist_test = [f\"/data/jianingy/tanks_and_temples/Ignatius/images/{i:08d}.jpg\" for i in range(1, 262, 1)]\n",
    "\n",
    "# RealEstate10K\n",
    "# randomly sample 10 files from /data/jianingy/RealEstate10K/videos/test/93e6c08c33206a0c\n",
    "# Randomly sample 10 files from the specified directory\n",
    "# def sample_random_files(directory, n=10):\n",
    "#     all_files = [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n",
    "#     return random.sample(all_files, n)\n",
    "\n",
    "# # Sample 10 random files from the RealEstate10K directory\n",
    "# filelist_test = sample_random_files(\"/data/jianingy/RealEstate10K/videos/test/924ccc02891cc7df\", 10)\n",
    "\n",
    "\n",
    "# # reverse the order\n",
    "# filelist_test = filelist_test[::-1]\n",
    "\n",
    "# filelist_test = [f\"/home/ssax/InstantSplat/data/collated_instantsplat_data/eval/Tanks/Barn/images/000{i:03d}.jpg\" for i in range(521, 670, 5)]\n",
    "\n",
    "\n",
    "# multi-cam dynamic scenes\n",
    "# Function to get sorted file list from a directory\n",
    "def get_sorted_file_list(directory):\n",
    "    return sorted([os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))])\n",
    "\n",
    "# filelist_test = [get_sorted_file_list(f\"/path/to/juggling_multicam/{cam_idx}/\")[149] for cam_idx in range(0, 8)]\n",
    "\n",
    "\n",
    "\n",
    "# display the images\n",
    "def display_images(filelist, title, rotate_clockwise_90=False, crop_to_landscape=False):\n",
    "    fig, axes = plt.subplots(1, len(filelist), figsize=(30, 4))\n",
    "    fig.suptitle(title)\n",
    "    for ax, filepath in zip(axes if hasattr(axes, '__iter__') else [axes], filelist):\n",
    "        img = Image.open(filepath)\n",
    "        if rotate_clockwise_90:\n",
    "            img = img.rotate(-90, expand=True)\n",
    "        if crop_to_landscape:\n",
    "            # Crop to a landscape aspect ratio (e.g., 16:9)\n",
    "            desired_aspect_ratio = 4 / 3\n",
    "            width, height = img.size\n",
    "            current_aspect_ratio = width / height\n",
    "\n",
    "            if current_aspect_ratio > desired_aspect_ratio:\n",
    "                # Wider than landscape: crop width\n",
    "                new_width = int(height * desired_aspect_ratio)\n",
    "                left = (width - new_width) // 2\n",
    "                right = left + new_width\n",
    "                top = 0\n",
    "                bottom = height\n",
    "            else:\n",
    "                # Taller than landscape: crop height\n",
    "                new_height = int(width / desired_aspect_ratio)\n",
    "                top = (height - new_height) // 2\n",
    "                bottom = top + new_height\n",
    "                left = 0\n",
    "                right = width\n",
    "            \n",
    "            img = img.crop((left, top, right, bottom))\n",
    "        \n",
    "        ax.imshow(img)\n",
    "        ax.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "# # Display train images\n",
    "# display_images(filelist_train, 'Train Images')\n",
    "\n",
    "# Display test images\n",
    "display_images(filelist_test, 'Test Images')\n",
    "# display_images(filelist_test, 'Test Images', rotate_clockwise_90=True)\n",
    "# display_images(filelist_test, 'Test Images', rotate_clockwise_90=True, crop_to_landscape=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# skip this\n",
    "device = torch.device(\"cuda\")\n",
    "\n",
    "checkpoint_root = \"/path/to/checkpoint_root\"\n",
    "\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_longer_epochs/checkpoint-best.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_multiview/checkpoint-best.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_multiview_co3d_full/checkpoint-last.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_demo_224_multiview_co3d_full_100_epochs_100_samples_per_window/checkpoint-last.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset/checkpoint-best.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_dec_and_head/checkpoint-last.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything/checkpoint-last.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_large/checkpoint-last.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth/checkpoint-last.pth').to(device)\n",
    "model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/dust3r_512_dpt_finetune_multiview_co3d_50_seqs_per_cat_subset_100_epochs_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/bf16_flash_attn_unfreeze_everything_co3d_scannetpp_megadepth_large_bs4/checkpoint-10.pth').to(device)\n",
    "# model = FlashDUSt3R.from_pretrained(f'{checkpoint_root}/bf16_flash_attn_unfreeze_everything_co3d_scannetpp_megadepth_large/checkpoint-last.pth').to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lightning model\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict\n",
    "\n",
    "# device = torch.device(\"cuda:2\")\n",
    "device = torch.device(\"cuda\")\n",
    "\n",
    "checkpoint_root = \"path/to/checkpoint_root\"\n",
    "\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/train/runs/2024-08-13_04-40-37\"  #fp32-fancy-sun-181\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/train/runs/2024-08-13_08-06-08\"  #fp32_workers11_giddy-gorge-182\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_3782640\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs2_views8/runs/fp32_bs2_views8_3782638\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4007485\"  # with random image idx embeddings\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4030983\"  # fix Regr3D loss (wrong rotation)\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4/runs/fp32_bs6_views4_4037511\"  # fix Regr3D loss (fixed rotation)\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_scannetpp_only/runs/fp32_bs6_views4_scannetpp_only_4060428\"  # ScanNet++ only no random emb\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_scannetpp_only/runs/fp32_bs6_views4_scannetpp_only_4051504\"  # ScanNet++ only\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fp32_bs6_views4_arkitscenes_only/runs/arkitscenes_only_4123064\"  # ARKitScenes only\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/arkitscenes_only_no_pairs/runs/arkitscenes_only_no_pairs_4129400\"  # ARKitScenes only no pairs\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp/runs/co3d_scannetpp_4123062\"  # co3d_scannetpp\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp_arkitscenes/runs/co3d_scannetpp_arkitscenes_4123063\"  # co3d_scannetpp_arkitscenes\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp_arkitscenes_bs2_views8/runs/co3d_scannetpp_arkitscenes_bs2_views8_4155008\"  # co3d_scannetpp_arkitscenes 8 views\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/co3d_scannetpp_arkitscenes_better_random_pose_emb/runs/co3d_scannetpp_arkitscenes_better_random_pose_emb_4323524\"  # co3d_scannetpp_arkitscenes 8 views\n",
    "\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_co3d_scannetpp_arkitscenes_better_random_pose_emb/runs/fast3r_co3d_scannetpp_arkitscenes_better_random_pose_emb_4365927\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_habitat_larger_decoder_views4/runs/fast3r_habitat_larger_decoder_views4_4383740\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_habitat_larger_decoder_views8/runs/fast3r_habitat_larger_decoder_views8_4383741\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_larger_decoder_bs1_views4/runs/fast3r_larger_decoder_4371625\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_habitat_co3d_scannetpp_arkitscenes/runs/fast3r_habitat_co3d_scannetpp_arkitscenes_4383742\"\n",
    "\n",
    "# local head\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_no_local_head/runs/fast3r_no_local_head_4611636\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_no_local_head_habitat/runs/fast3r_no_local_head_habitat_4615832\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head/runs/fast3r_local_head_4638120\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_habitat/runs/fast3r_local_head_habitat_4626119\"\n",
    "\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_habitat_better_scannetpp/runs/fast3r_local_head_habitat_better_scannetpp_4701731\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_habitat_better_scannetpp_8views/runs/fast3r_local_head_habitat_better_scannetpp_8views_4726417\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_16views/runs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_16views_4793676\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_20views_small_lr/runs/fast3r_local_head_better_scannetpp_and_arkit_pretrained_20views_small_lr_4804512\"\n",
    "\n",
    "\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/more_co3d_finetune_16views/runs/more_co3d_finetune_16views_4865625\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/more_co3d_finetune_20views/runs/more_co3d_finetune_20views_4867088\"\n",
    "\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/super_long_training/runs/super_long_training_5031318\"\n",
    "# checkpoint_dir = f\"{checkpoint_root}/dust3r/fast3r/logs/super_long_training/runs/super_long_training_5078043\"\n",
    "checkpoint_dir = f\"/data/jianingy/dust3r_data/fast3r_checkpoints/super_long_training_5175604\"\n",
    "\n",
    "\n",
    "print(\"Creating an empty lightning module to hold the weights...\")\n",
    "cfg = OmegaConf.load(os.path.join(checkpoint_dir, '.hydra/config.yaml'))\n",
    "\n",
    "# replace all occurances of \"dust3r.\" in cfg.model.net with \"fast3r.dust3r.\" (this is due to relocation of our code)\n",
    "def replace_dust3r_in_config(cfg):\n",
    "    for key, value in cfg.items():\n",
    "        if isinstance(value, DictConfig):\n",
    "            replace_dust3r_in_config(value)\n",
    "        elif isinstance(value, str):\n",
    "            if \"dust3r.\" in value and \"fast3r.dust3r.\" not in value:\n",
    "                cfg[key] = value.replace(\"dust3r.\", \"fast3r.dust3r.\")\n",
    "    return cfg\n",
    "\n",
    "def replace_src_in_config(cfg_dict):\n",
    "    for key, value in cfg_dict.items():\n",
    "        if isinstance(value, DictConfig):\n",
    "            replace_src_in_config(value)\n",
    "        elif isinstance(value, str) and \"src.\" in value:\n",
    "            cfg_dict[key] = value.replace(\"src.\", \"fast3r.\")\n",
    "    return cfg_dict\n",
    "\n",
    "cfg.model.net = replace_dust3r_in_config(cfg.model.net)\n",
    "cfg.model = replace_src_in_config(cfg.model)\n",
    "\n",
    "if \"encoder_args\" in cfg.model.net:\n",
    "    cfg.model.net.encoder_args.patch_embed_cls = \"PatchEmbedDust3R\"\n",
    "    cfg.model.net.head_args.landscape_only = False\n",
    "else:\n",
    "    cfg.model.net.patch_embed_cls = \"PatchEmbedDust3R\"  # TODO: investigate what exactly this does, this seems to support inferencing images of protrait orientation\n",
    "    cfg.model.net.landscape_only = False  # TODO: investigate what exactly this does\n",
    "\n",
    "\n",
    "cfg.model.net.decoder_args.random_image_idx_embedding = True # try to load the model without random image idx embeddings\n",
    "\n",
    "# enable attention biasing for inference more views than training\n",
    "cfg.model.net.decoder_args.attn_bias_for_inference_enabled = False\n",
    "\n",
    "lit_module = hydra.utils.instantiate(cfg.model, train_criterion=None, validation_criterion=None)\n",
    "\n",
    "\n",
    "print(\"Loading weights from checkpoint...\")\n",
    "\n",
    "\n",
    "# check if checkpoint_dir + \"/checkpoints/last.ckpt\" is a directory, if so, load the last checkpoint from that directory\n",
    "if os.path.isdir(checkpoint_dir + \"/checkpoints/last.ckpt\"):\n",
    "    # it is a DeepSpeed checkpoint, convert it to a regular checkpoint\n",
    "    CKPT_PATH = os.path.join(checkpoint_dir, 'checkpoints/last_aggregated.ckpt')\n",
    "    if not os.path.exists(CKPT_PATH):\n",
    "        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir=checkpoint_dir + \"/checkpoints/last.ckpt\", output_file=CKPT_PATH, tag=None)\n",
    "else:\n",
    "    CKPT_PATH = os.path.join(checkpoint_dir, 'checkpoints/last.ckpt')\n",
    "\n",
    "lit_module = MultiViewDUSt3RLitModule.load_from_checkpoint(checkpoint_path=CKPT_PATH,\n",
    "                                                           net=lit_module.net,\n",
    "                                                           train_criterion=lit_module.train_criterion,\n",
    "                                                           validation_criterion=lit_module.validation_criterion,)\n",
    "lit_module.eval()\n",
    "model = lit_module.net.to(device)\n",
    "\n",
    "# model = torch.compile(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.set_max_parallel_views_for_head(150) # set the maximum number of parallel views for the head\n",
    "model.set_max_parallel_views_for_head(20)\n",
    "\n",
    "output = get_reconstructed_scene(\n",
    "    outdir = \"./output\",\n",
    "    model = model,\n",
    "    device = device,\n",
    "    silent = False,\n",
    "    # image_size = 224,\n",
    "    image_size = 512,\n",
    "    filelist = filelist_test,\n",
    "    profiling=True,\n",
    "    dtype = torch.float32,\n",
    "    # dtype = torch.bfloat16,\n",
    ")\n",
    "\n",
    "\n",
    "# local to global alignment\n",
    "# before fix: lit_module.align_local_pts3d_to_global(preds=output['preds'], views=output['views'], min_conf_thr_percentile=0)\n",
    "lit_module.align_local_pts3d_to_global(preds=output['preds'], views=output['views'], min_conf_thr_percentile=85)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Camera pose evaluation on RealEstate10K\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import glob\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "\n",
    "from fast3r.dust3r.datasets.utils.transforms import ImgNorm\n",
    "from fast3r.dust3r.utils.geometry import inv\n",
    "from fast3r.dust3r.utils.image import imread_cv2\n",
    "import fast3r.dust3r.datasets.utils.cropping as cropping\n",
    "\n",
    "# Suppose these references exist in your environment:\n",
    "# - inference(...) function\n",
    "# - lit_module that has evaluate_camera_poses(...)\n",
    "# - model variable\n",
    "# - crop_resize_if_necessary(...) from your snippet\n",
    "\n",
    "# set random seed for reproducibility\n",
    "random.seed(42)\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "def crop_resize_if_necessary(\n",
    "    image, \n",
    "    intrinsics_3x3, \n",
    "    target_resolution=(512, 288),\n",
    "    rng=None, \n",
    "    info=None\n",
    "):\n",
    "    \"\"\"\n",
    "    1. Crops around the principal point so that the principal point stays near center.\n",
    "    2. Rescales to target_resolution (landscape 512×288 or swapped for portrait).\n",
    "    3. Updates the intrinsics accordingly.\n",
    "    \n",
    "    In this example, 'depthmap' is not used (we pass None), but\n",
    "    the logic is the same as your snippet for applying transformations.\n",
    "    \n",
    "    Args:\n",
    "      image: A PIL.Image or numpy array (H×W×3).\n",
    "      intrinsics_3x3: np.array (3,3) camera intrinsics (pixel-based).\n",
    "      target_resolution: (W_target, H_target) in landscape mode. \n",
    "                        If the image is portrait, we may swap them.\n",
    "      rng: numpy RandomState or None; used if you do random data augmentation or random orientation.\n",
    "      info: optional debug string\n",
    "\n",
    "    Returns:\n",
    "      image_out: PIL.Image resized to final shape\n",
    "      intrinsics_out: The updated 3×3 matrix\n",
    "    \"\"\"\n",
    "    # Convert to PIL if needed\n",
    "    if isinstance(image, np.ndarray):\n",
    "        image = Image.fromarray(image)\n",
    "\n",
    "    # Pull out info from intrinsics\n",
    "    # intrinsics_3x3[0,0] = fx, intrinsics_3x3[1,1] = fy,\n",
    "    # intrinsics_3x3[0,2] = cx, intrinsics_3x3[1,2] = cy\n",
    "    W_org, H_org = image.size\n",
    "    cx, cy = int(round(intrinsics_3x3[0,2])), int(round(intrinsics_3x3[1,2]))\n",
    "\n",
    "    # Basic check if principal point is not obviously invalid:\n",
    "    min_margin_x = min(cx, W_org - cx)\n",
    "    min_margin_y = min(cy, H_org - cy)\n",
    "    if min_margin_x < W_org / 5 or min_margin_y < H_org / 5:\n",
    "        # You might raise an error or do fallback\n",
    "        # for example just center-crop in the middle\n",
    "        pass\n",
    "\n",
    "    # Crop around the principal point, symmetrical in x & y\n",
    "    left   = cx - min_margin_x\n",
    "    top    = cy - min_margin_y\n",
    "    right  = cx + min_margin_x\n",
    "    bottom = cy + min_margin_y\n",
    "\n",
    "    crop_bbox = (left, top, right, bottom)\n",
    "    # For depthmap = None, we can pass None to the cropping utility\n",
    "    image_c, _, intrinsics_c = cropping.crop_image_depthmap(\n",
    "        image, \n",
    "        None, \n",
    "        intrinsics_3x3, \n",
    "        crop_bbox\n",
    "    )\n",
    "\n",
    "    # image_c is now a PIL.Image with size = (2*min_margin_x, 2*min_margin_y)\n",
    "    W_c, H_c = image_c.size\n",
    "\n",
    "    # Adjust target_resolution if the image is \"portrait\"\n",
    "    # e.g. if H > W. \n",
    "    # If your logic is to always produce 512×288 for \"landscape\" and 288×512 for \"portrait\":\n",
    "    # You can check aspect ratio:\n",
    "    if H_c > W_c:\n",
    "        # Swap if we need a \"portrait\" orientation\n",
    "        # (288×512 instead of 512×288)\n",
    "        target_resolution = (target_resolution[1], target_resolution[0])\n",
    "\n",
    "    # Now do a high-quality downscale (Lanczos)\n",
    "    # You can keep the same approach as your snippet or randomize if you do data augmentation\n",
    "    image_rs, _, intrinsics_rs = cropping.rescale_image_depthmap(\n",
    "        image_c, None, intrinsics_c, np.array(target_resolution)\n",
    "    )\n",
    "\n",
    "    # If there's still a small difference or if you do a final crop:\n",
    "    intrinsics2 = cropping.camera_matrix_of_crop(\n",
    "        intrinsics_rs, image_rs.size, target_resolution, offset_factor=0.5\n",
    "    )\n",
    "    final_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics_rs, intrinsics2, target_resolution)\n",
    "\n",
    "    image_out, _, intrinsics_out = cropping.crop_image_depthmap(\n",
    "        image_rs, None, intrinsics_rs, final_bbox\n",
    "    )\n",
    "\n",
    "    return image_out, intrinsics_out\n",
    "\n",
    "import glob\n",
    "import random\n",
    "\n",
    "# RealEstate10K test split: one folder of JPG frames per video + one camera-parameter .txt per video\n",
    "re10k_video_root = \"/data/jianingy/RealEstate10K/videos/test\"\n",
    "re10k_txt_root   = \"/data/jianingy/RealEstate10K/test\"\n",
    "# Per-video camera-pose results are written here as <video_name>.txt (created below if missing)\n",
    "re10k_eval_out_dir = \"/home/jianingy/research/fast3r/notebooks/RealEstate10K_eval\"\n",
    "\n",
    "N_FRAMES_PER_VIDEO = 10        # max number of frames sampled per video\n",
    "TARGET_RESOLUTION = (512, 288) # output size passed to crop_resize_if_necessary (512×288)\n",
    "FRAME_SAMPLING_SEED = 0        # seeds frame sampling so the evaluation is reproducible\n",
    "\n",
    "# video_folders = sorted(os.listdir(re10k_video_root))\n",
    "# video_folders = ['9414231317ded453']\n",
    "# video_folders = ['0090cc64d7b7bb24']  # worst scene\n",
    "video_folders = ['0be9a0dcbfe032f1']  # worst scene\n",
    "\n",
    "os.makedirs(re10k_eval_out_dir, exist_ok=True)\n",
    "\n",
    "for vid_folder in tqdm(video_folders, desc=\"Evaluating RealEstate10K Test Videos\"):\n",
    "    folder_path = os.path.join(re10k_video_root, vid_folder)\n",
    "    if not os.path.isdir(folder_path):\n",
    "        continue\n",
    "\n",
    "    txt_path = os.path.join(re10k_txt_root, vid_folder + \".txt\")\n",
    "    if not os.path.exists(txt_path):\n",
    "        # no camera .txt for this video => skip\n",
    "        continue\n",
    "\n",
    "    # 1) Build a dict mapping \"frame ID\" (first column) => whitespace-split line columns.\n",
    "    #    The first line of the .txt is the video URL, so skip it.\n",
    "    with open(txt_path, \"r\") as txt_file:\n",
    "        txt_lines = txt_file.read().strip().splitlines()\n",
    "    if len(txt_lines) <= 1:\n",
    "        continue\n",
    "\n",
    "    txt_lines = txt_lines[1:]  # skip the URL line\n",
    "    # e.g. lines_map[\"308641667\"] = [frame_id, fx, fy, cx, cy, <2 fields not used here>, <12 extrinsic values>]\n",
    "    lines_map = {}\n",
    "    for line in txt_lines:\n",
    "        parts = line.strip().split()\n",
    "        # Expected: 1 frame id + 4 intrinsics + 2 unused + 12 extrinsic = 19 fields\n",
    "        if len(parts) < 19:\n",
    "            continue\n",
    "        lines_map[parts[0]] = parts\n",
    "\n",
    "    # 2) Gather all JPG frames in this folder\n",
    "    frame_files = sorted(glob.glob(os.path.join(folder_path, \"*.jpg\")))\n",
    "    if len(frame_files) < 2:\n",
    "        continue\n",
    "\n",
    "    # 3) Sample a (seeded, hence reproducible) subset of the frame files;\n",
    "    #    each file is matched to its .txt line by base filename below.\n",
    "    n_to_sample = min(N_FRAMES_PER_VIDEO, len(frame_files))\n",
    "    sampled_frames = random.Random(FRAME_SAMPLING_SEED).sample(frame_files, n_to_sample)\n",
    "    # sampled_frames = frame_files[:n_to_sample]\n",
    "\n",
    "    # 4) Build \"views\" for each sampled frame, in filename order\n",
    "    selected_views = []\n",
    "    for frame_path in sorted(sampled_frames):\n",
    "        # Extract \"308641667\" from \"308641667.jpg\"\n",
    "        basename = os.path.splitext(os.path.basename(frame_path))[0]\n",
    "\n",
    "        # The .txt may not list every frame (or names may mismatch) => skip unmatched frames\n",
    "        if basename not in lines_map:\n",
    "            continue\n",
    "\n",
    "        columns = lines_map[basename]\n",
    "\n",
    "        # Normalized intrinsics: columns[1:5] = fx, fy, cx, cy\n",
    "        fx, fy, cx, cy = (float(v) for v in columns[1:5])\n",
    "\n",
    "        # columns[7:19] is a 3x4 extrinsic stored row-major, i.e. [R | t] with the\n",
    "        # translation in the LAST COLUMN.\n",
    "        extrinsic = np.array([float(v) for v in columns[7:19]], dtype=np.float64).reshape(3, 4)\n",
    "\n",
    "        # Promote to a 4x4 homogeneous matrix; its inverse is used as the GT camera-to-world pose\n",
    "        pose_4x4 = np.eye(4, dtype=np.float32)\n",
    "        pose_4x4[:3, :4] = extrinsic\n",
    "        pose_c2w_gt = inv(pose_4x4)\n",
    "\n",
    "        # read image; NOTE(review): assumes imread_cv2 returns an RGB (H, W, 3) array — confirm\n",
    "        img_rgb = imread_cv2(frame_path)\n",
    "        if img_rgb is None:\n",
    "            continue\n",
    "\n",
    "        H_org, W_org = img_rgb.shape[:2]\n",
    "\n",
    "        # Denormalize intrinsics by the original image size:\n",
    "        # K = [[fx*W, 0, cx*W], [0, fy*H, cy*H], [0, 0, 1]]\n",
    "        K_3x3 = np.array([\n",
    "            [fx * W_org, 0.0,        cx * W_org],\n",
    "            [0.0,        fy * H_org, cy * H_org],\n",
    "            [0.0,        0.0,        1.0       ],\n",
    "        ], dtype=np.float32)\n",
    "\n",
    "        # If imread_cv2 returned BGR, convert first: img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB)\n",
    "        pil_img = Image.fromarray(img_rgb)\n",
    "\n",
    "        # Crop + resize around the principal point to the target resolution\n",
    "        final_img_pil, final_intrinsics_3x3 = crop_resize_if_necessary(\n",
    "            image=pil_img,\n",
    "            intrinsics_3x3=K_3x3,\n",
    "            target_resolution=TARGET_RESOLUTION,\n",
    "            rng=np.random,\n",
    "            info=f\"{vid_folder}_{basename}\"\n",
    "        )\n",
    "\n",
    "        # Normalize to [-1,1], channel-first (3, H, W)\n",
    "        tensor_chw = ImgNorm(final_img_pil)\n",
    "\n",
    "        # PIL .size is (width, height); true_shape is stored as (height, width)\n",
    "        img_w, img_h = final_img_pil.size\n",
    "        selected_views.append({\n",
    "            \"img\": tensor_chw.unsqueeze(0),  # (1, 3, H, W)\n",
    "            \"camera_pose\": torch.from_numpy(pose_c2w_gt).unsqueeze(0),  # (1, 4, 4)\n",
    "            \"camera_intrinsics\": torch.from_numpy(final_intrinsics_3x3).unsqueeze(0),  # (1, 3, 3)\n",
    "            \"dataset\": [\"RealEstate10K\"],\n",
    "            \"true_shape\": torch.tensor([[img_h, img_w]]),  # (1, 2) = (height, width)\n",
    "        })\n",
    "\n",
    "    # Need at least two views for a multi-view reconstruction\n",
    "    if len(selected_views) < 2:\n",
    "        continue\n",
    "\n",
    "    # 5) Run inference\n",
    "    output = inference(\n",
    "        selected_views,\n",
    "        model=model,\n",
    "        device=torch.device(\"cuda\"),\n",
    "        dtype=torch.float32,\n",
    "        verbose=False,\n",
    "        profiling=False\n",
    "    )\n",
    "\n",
    "    # 6) Evaluate camera poses; the call returns one result per batch item and batch size is 1 here\n",
    "    cam_pose_result = lit_module.evaluate_camera_poses(\n",
    "        views=output[\"views\"],\n",
    "        preds=output[\"preds\"],\n",
    "        niter_PnP=100,\n",
    "        focal_length_estimation_method='first_view_from_global_head'\n",
    "        # focal_length_estimation_method='first_view_from_local_head'\n",
    "    )[0]\n",
    "\n",
    "    # Tag the result with the video name and write its repr to <out_dir>/<video_name>.txt\n",
    "    cam_pose_result[\"video_name\"] = vid_folder\n",
    "    with open(os.path.join(re10k_eval_out_dir, f\"{vid_folder}.txt\"), \"w\") as out_file:\n",
    "        out_file.write(str(cam_pose_result))\n",
    "\n",
    "    # lit_module.align_local_pts3d_to_global(preds=output['preds'], views=output['views'], min_conf_thr_percentile=85)\n",
    "\n",
    "\n",
    "print(\"All done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Multi-camera juggling sequence: one frame folder per camera\n",
    "juggling_root = \"/path/to/juggling_multicam\"\n",
    "num_cameras = 8\n",
    "num_timesteps = 149\n",
    "\n",
    "# Listing each camera folder is loop-invariant, so do it once up front\n",
    "frames_per_camera = [get_sorted_file_list(f\"{juggling_root}/{cam_idx}/\") for cam_idx in range(num_cameras)]\n",
    "model.set_max_parallel_views_for_head(150) # set the maximum number of parallel views for the head\n",
    "\n",
    "for timestep in tqdm(range(num_timesteps)):\n",
    "    # One frame per camera at the CURRENT timestep (not a fixed frame index)\n",
    "    filelist_test = [camera_frames[timestep] for camera_frames in frames_per_camera]\n",
    "\n",
    "    output = get_reconstructed_scene(\n",
    "        outdir = \"./output\",\n",
    "        model = model,\n",
    "        device = device,\n",
    "        silent = False,\n",
    "        # image_size = 224,\n",
    "        image_size = 512,\n",
    "        filelist = filelist_test,\n",
    "        profiling=True,\n",
    "        dtype = torch.float32,\n",
    "    )\n",
    "\n",
    "    # local to global alignment\n",
    "    # before fix: lit_module.align_local_pts3d_to_global(preds=output['preds'], views=output['views'], min_conf_thr_percentile=0)\n",
    "    lit_module.align_local_pts3d_to_global(preds=output['preds'], views=output['views'], min_conf_thr_percentile=85)\n",
    "\n",
    "    # Fresh per-timestep output folder tree\n",
    "    img_output_dir = f\"./output/juggling_multicam/{timestep}\"\n",
    "    rgb_dir = os.path.join(img_output_dir, \"rgb_images\")\n",
    "    global_conf_dir = os.path.join(img_output_dir, \"global_confidence_maps\")\n",
    "    local_depth_conf_dir = os.path.join(img_output_dir, \"local_depth_and_confidence_maps\")\n",
    "    dirs_to_create = [img_output_dir, rgb_dir, global_conf_dir, local_depth_conf_dir]\n",
    "\n",
    "    if os.path.exists(img_output_dir):\n",
    "        shutil.rmtree(img_output_dir)\n",
    "    for d in dirs_to_create:\n",
    "        os.makedirs(d, exist_ok=True)\n",
    "\n",
    "    # RGB inputs, global confidence maps, local depth + confidence maps\n",
    "    plot_rgb_images(output['views'], save_image_to_folder=rgb_dir)\n",
    "    plot_confidence_maps(output['preds'], save_image_to_folder=global_conf_dir)\n",
    "    maybe_plot_local_depth_and_conf(output['preds'], save_image_to_folder=local_depth_conf_dir)\n",
    "\n",
    "    export_combined_ply(\n",
    "        preds=output['preds'],\n",
    "        views=output['views'],\n",
    "        pts3d_key_to_visualize=\"pts3d_local_aligned_to_global\",\n",
    "        conf_key_to_visualize=\"conf_local\",\n",
    "        export_ply_path=os.path.join(img_output_dir, \"combined_pointcloud.ply\"),\n",
    "        min_conf_thr_percentile=45,\n",
    "        flip_axes=True,\n",
    "        max_num_points=1_000_000,      # cap on exported points\n",
    "        sampling_strategy='uniform'    # Choose 'uniform', 'voxel', or 'farthest_point'\n",
    "    )\n",
    "\n",
    "    save_pointmaps_and_camera_parameters_to_folder(preds=output['preds'], save_folder=img_output_dir, niter_PnP=100, focal_length_estimation_method='first_view_from_global_head')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt  # make `plt` available even if only the `pl` alias was imported\n",
    "\n",
    "# Note from an earlier run: low conf views: [13:24]\n",
    "HIGH_CONF_THRESHOLD = 1.5  # a view counts as high-confidence if its max global confidence exceeds this\n",
    "\n",
    "# Stack per-view global confidence maps -> [num_views, H, W]\n",
    "conf_list = [pred['conf'] for pred in output['preds']]\n",
    "conf = torch.stack(conf_list, dim=0).squeeze(1)\n",
    "\n",
    "# Max confidence per view -> [num_views]; moved to CPU float32 so .numpy() works for any device/dtype\n",
    "max_conf_per_view = conf.reshape(conf.shape[0], -1).max(dim=1).values.detach().float().cpu()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(25, 4))\n",
    "ax.bar(range(len(max_conf_per_view)), max_conf_per_view.numpy())\n",
    "ax.axhline(y=HIGH_CONF_THRESHOLD, color='r', linestyle='--')  # threshold line\n",
    "ax.set_xlabel('View Index')\n",
    "ax.set_ylabel('Max Confidence Score')\n",
    "ax.set_title('Max Confidence Score per View')\n",
    "plt.show()\n",
    "\n",
    "num_high_conf_views = int((max_conf_per_view > HIGH_CONF_THRESHOLD).sum())\n",
    "print(f\"Number of views with confidence score > {HIGH_CONF_THRESHOLD}: {num_high_conf_views} out of {len(max_conf_per_view)}\")\n",
    "\n",
    "# Statistics of the per-view MAX confidence (not of all pixels)\n",
    "print(f\"Average of per-view max confidence: {max_conf_per_view.mean().item():.4f}\")\n",
    "print(f\"Median of per-view max confidence: {max_conf_per_view.median().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %load_ext autoreload\n",
    "# %autoreload 2\n",
    "\n",
    "# Export diagnostics (RGB views, confidence maps, fused point cloud, cameras) for the\n",
    "# `output` currently in the kernel into a freshly wiped ./output/nrgbd_kitchen folder.\n",
    "img_output_dir = \"./output/nrgbd_kitchen\"\n",
    "nrgbd_rgb_dir, nrgbd_global_conf_dir, nrgbd_local_depth_conf_dir = [\n",
    "    os.path.join(img_output_dir, sub_name)\n",
    "    for sub_name in (\"rgb_images\", \"global_confidence_maps\", \"local_depth_and_confidence_maps\")\n",
    "]\n",
    "dirs_to_create = [img_output_dir, nrgbd_rgb_dir, nrgbd_global_conf_dir, nrgbd_local_depth_conf_dir]\n",
    "\n",
    "# Remove any previous export, then recreate the folder tree (parent first)\n",
    "if os.path.exists(img_output_dir):\n",
    "    shutil.rmtree(img_output_dir)\n",
    "for dir_path in dirs_to_create:\n",
    "    if os.path.exists(dir_path):\n",
    "        continue\n",
    "    os.makedirs(dir_path)\n",
    "\n",
    "# Per-view RGB inputs\n",
    "plot_rgb_images(output['views'], save_image_to_folder=nrgbd_rgb_dir)\n",
    "# Global-head confidence maps\n",
    "plot_confidence_maps(output['preds'], save_image_to_folder=nrgbd_global_conf_dir)\n",
    "# Local depth and confidence maps\n",
    "maybe_plot_local_depth_and_conf(output['preds'], save_image_to_folder=nrgbd_local_depth_conf_dir)\n",
    "\n",
    "# Fused point cloud from the local points aligned to the global frame\n",
    "export_combined_ply(\n",
    "    preds=output['preds'],\n",
    "    views=output['views'],\n",
    "    pts3d_key_to_visualize=\"pts3d_local_aligned_to_global\",\n",
    "    conf_key_to_visualize=\"conf_local\",\n",
    "    export_ply_path=os.path.join(img_output_dir, \"combined_pointcloud.ply\"),\n",
    "    min_conf_thr_percentile=15,\n",
    "    flip_axes=True,\n",
    "    max_num_points=1_000_000,     # cap on exported points\n",
    "    sampling_strategy='uniform',  # one of 'uniform', 'voxel', 'farthest_point'\n",
    ")\n",
    "\n",
    "save_pointmaps_and_camera_parameters_to_folder(\n",
    "    preds=output['preds'],\n",
    "    save_folder=img_output_dir,\n",
    "    niter_PnP=100,\n",
    "    focal_length_estimation_method='first_view_from_global_head',\n",
    ")\n",
    "\n",
    "# Plot the 3D points along with estimated camera poses\n",
    "# plot_3d_points_with_estimated_camera_poses(\n",
    "#     output['preds'],  # Predictions containing 3D points\n",
    "#     output['views'],  # Views containing RGB images\n",
    "#     flip_axes=True,   # Enable flipping of axes (swap Y and Z and flip Z)\n",
    "#     min_conf_thr_percentile=0,  # Confidence threshold percentile for filtering points\n",
    "#     # export_ply_path='./output/combined_mesh.ply',  # Export path for the .ply file\n",
    "#     export_html_path='./output/combined_mesh.html'  # Export path for the .html file\n",
    "# )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stop a viser server bound to the global name `server`, if one exists; guarded so that\n",
    "# Restart & Run All does not raise NameError before any server has been created.\n",
    "if \"server\" in globals():\n",
    "    server.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# viser visualization\n",
    "\n",
    "import time\n",
    "import threading\n",
    "import numpy as np\n",
    "from tqdm.auto import tqdm\n",
    "import imageio.v3 as iio\n",
    "from matplotlib import cm\n",
    "\n",
    "import viser\n",
    "import viser.transforms as tf\n",
    "from fast3r.dust3r.utils.device import to_numpy\n",
    "\n",
    "def start_visualization(output, min_conf_thr_percentile=10, global_conf_thr_value_to_drop_view=1.5, port=8020):\n",
    "    # Create the viser server on the specified port\n",
    "    server = viser.ViserServer(host='127.0.0.1', port=port)\n",
    "\n",
    "    # Estimate camera poses\n",
    "    poses_c2w_batch, estimated_focals = MultiViewDUSt3RLitModule.estimate_camera_poses(\n",
    "        output['preds'], niter_PnP=100, focal_length_estimation_method='first_view_from_global_head'\n",
    "    )\n",
    "    poses_c2w = poses_c2w_batch[0]  # Assuming batch size of 1\n",
    "\n",
    "    # Set the upward direction to negative Y-axis\n",
    "    server.scene.set_up_direction((0.0, -1.0, 0.0))\n",
    "    server.scene.world_axes.visible = False  # Optional: Hide world axes\n",
    "\n",
    "    num_frames = len(output['preds'])\n",
    "\n",
    "    # Prepare lists to store per-frame data\n",
    "    frame_data_list = []\n",
    "\n",
    "    # Generate colors for frustums and points in rainbow order\n",
    "    def rainbow_color(n, total):\n",
    "        import colorsys\n",
    "        hue = n / total\n",
    "        rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)\n",
    "        return rgb\n",
    "\n",
    "    # Add playback UI\n",
    "    with server.gui.add_folder(\"Playback\"):\n",
    "        gui_point_size = server.gui.add_slider(\"Point size\", min=0.000001, max=0.002, step=1e-5, initial_value=0.0005)\n",
    "        gui_frustum_size_percent = server.gui.add_slider(\"Camera Size (%)\", min=0.1, max=10.0, step=0.1, initial_value=2.0)\n",
    "        gui_timestep = server.gui.add_slider(\"Timestep\", min=0, max=num_frames - 1, step=1, initial_value=0, disabled=True)\n",
    "        gui_next_frame = server.gui.add_button(\"Next Frame\", disabled=True)\n",
    "        gui_prev_frame = server.gui.add_button(\"Prev Frame\", disabled=True)\n",
    "        gui_playing = server.gui.add_checkbox(\"Playing\", True)\n",
    "        gui_framerate = server.gui.add_slider(\"FPS\", min=0.25, max=60, step=0.25, initial_value=10)\n",
    "        gui_framerate_options = server.gui.add_button_group(\"FPS options\", (\"0.5\", \"1\", \"10\", \"20\", \"30\", \"60\"))\n",
    "\n",
    "    # Add point cloud options UI\n",
    "    with server.gui.add_folder(\"Point Cloud Options\"):\n",
    "        gui_show_global = server.gui.add_checkbox(\"Global\", False)\n",
    "        gui_show_local = server.gui.add_checkbox(\"Local\", True)\n",
    "\n",
    "    # Add view options UI\n",
    "    with server.gui.add_folder(\"View Options\"):\n",
    "        gui_show_high_conf = server.gui.add_checkbox(\"Show High-Conf Views\", True)\n",
    "        gui_show_low_conf = server.gui.add_checkbox(\"Show Low-Conf Views\", False)\n",
    "        gui_global_conf_threshold = server.gui.add_slider(\"High/Low Conf threshold value\", min=1.0, max=12.0, step=0.1, initial_value=global_conf_thr_value_to_drop_view)\n",
    "        gui_min_conf_percentile = server.gui.add_slider(\"Per-View conf percentile\", min=0, max=100, step=1, initial_value=min_conf_thr_percentile)\n",
    "\n",
    "    # Add color options UI\n",
    "    with server.gui.add_folder(\"Color Options\"):\n",
    "        gui_show_confidence = server.gui.add_checkbox(\"Show Confidence\", False)\n",
    "        gui_rainbow_color = server.gui.add_checkbox(\"Rainbow Colors\", False)\n",
    "\n",
    "    button_render_gif = server.gui.add_button(\"Render a GIF\")\n",
    "\n",
    "    # Frame step buttons\n",
    "    @gui_next_frame.on_click\n",
    "    def _(_) -> None:\n",
    "        gui_timestep.value = (gui_timestep.value + 1) % num_frames\n",
    "\n",
    "    @gui_prev_frame.on_click\n",
    "    def _(_) -> None:\n",
    "        gui_timestep.value = (gui_timestep.value - 1) % num_frames\n",
    "\n",
    "    # Disable frame controls when we're playing\n",
    "    @gui_playing.on_update\n",
    "    def _(_) -> None:\n",
    "        gui_timestep.disabled = gui_playing.value\n",
    "        gui_next_frame.disabled = gui_playing.value\n",
    "        gui_prev_frame.disabled = gui_playing.value\n",
    "\n",
    "    # Set the framerate when we click one of the options\n",
    "    @gui_framerate_options.on_click\n",
    "    def _(_) -> None:\n",
    "        gui_framerate.value = float(gui_framerate_options.value)\n",
    "\n",
    "    server.scene.add_frame(\"/cams\", show_axes=False)\n",
    "\n",
    "    # First pass: Collect data and compute scene extent\n",
    "    cumulative_pts = []\n",
    "\n",
    "    for i in tqdm(range(num_frames)):\n",
    "        pred = output['preds'][i]\n",
    "        view = output['views'][i]\n",
    "\n",
    "        # Extract global and local points and confidences\n",
    "        pts3d_global = to_numpy(pred['pts3d_in_other_view'].cpu().squeeze())\n",
    "        conf_global = to_numpy(pred['conf'].cpu().squeeze())\n",
    "        pts3d_local = to_numpy(pred['pts3d_local_aligned_to_global'].cpu().squeeze())\n",
    "        conf_local = to_numpy(pred['conf_local'].cpu().squeeze())\n",
    "        img_rgb = to_numpy(view['img'].cpu().squeeze().permute(1, 2, 0))\n",
    "\n",
    "        # Reshape and flatten data\n",
    "        pts3d_global = pts3d_global.reshape(-1, 3)\n",
    "        pts3d_local = pts3d_local.reshape(-1, 3)\n",
    "        img_rgb = img_rgb.reshape(-1, 3)\n",
    "        conf_global = conf_global.flatten()\n",
    "        conf_local = conf_local.flatten()\n",
    "\n",
    "        cumulative_pts.append(pts3d_global)\n",
    "\n",
    "        # Store per-frame data\n",
    "        frame_data = {}\n",
    "\n",
    "        # Sort points by confidence in descending order\n",
    "        # For global point cloud\n",
    "        sort_indices_global = np.argsort(-conf_global)\n",
    "        sorted_conf_global = conf_global[sort_indices_global]\n",
    "        sorted_pts3d_global = pts3d_global[sort_indices_global]\n",
    "        sorted_img_rgb_global = img_rgb[sort_indices_global]\n",
    "\n",
    "        # For local point cloud\n",
    "        sort_indices_local = np.argsort(-conf_local)\n",
    "        sorted_conf_local = conf_local[sort_indices_local]\n",
    "        sorted_pts3d_local = pts3d_local[sort_indices_local]\n",
    "        sorted_img_rgb_local = img_rgb[sort_indices_local]\n",
    "\n",
    "        # Normalize colors\n",
    "        colors_rgb_global = ((sorted_img_rgb_global + 1) * 127.5).astype(np.uint8) / 255.0  # Values in [0,1]\n",
    "        colors_rgb_local = ((sorted_img_rgb_local + 1) * 127.5).astype(np.uint8) / 255.0  # Values in [0,1]\n",
    "\n",
    "        # Precompute confidence-based colors\n",
    "        conf_norm_global = (sorted_conf_global - sorted_conf_global.min()) / (sorted_conf_global.max() - sorted_conf_global.min() + 1e-8)\n",
    "        conf_norm_local = (sorted_conf_local - sorted_conf_local.min()) / (sorted_conf_local.max() - sorted_conf_local.min() + 1e-8)\n",
    "        colormap = cm.turbo\n",
    "        colors_confidence_global = colormap(conf_norm_global)[:, :3]  # Values in [0,1]\n",
    "        colors_confidence_local = colormap(conf_norm_local)[:, :3]  # Values in [0,1]\n",
    "\n",
    "        # Rainbow color for the frame's points\n",
    "        rainbow_color_for_frame = rainbow_color(i, num_frames)\n",
    "        colors_rainbow_global = np.tile(rainbow_color_for_frame, (sorted_pts3d_global.shape[0], 1))\n",
    "        colors_rainbow_local = np.tile(rainbow_color_for_frame, (sorted_pts3d_local.shape[0], 1))\n",
    "\n",
    "        # Compute initial high-confidence flag based on global confidence\n",
    "        max_conf_global = conf_global.max()\n",
    "        is_high_confidence = max_conf_global >= gui_global_conf_threshold.value\n",
    "\n",
    "        # Camera parameters\n",
    "        c2w = poses_c2w[i]\n",
    "        height, width = view['img'].shape[2], view['img'].shape[3]\n",
    "        focal_length = estimated_focals[0][i]\n",
    "        img_rgb_reshaped = img_rgb.reshape(height, width, 3)\n",
    "        img_rgb_normalized = ((img_rgb_reshaped + 1) * 127.5).astype(np.uint8)  # Values in [0,255]\n",
    "        img_downsampled = img_rgb_normalized[::4, ::4]  # Keep as uint8\n",
    "\n",
    "        # Store all precomputed data\n",
    "        frame_data['sorted_pts3d_global'] = sorted_pts3d_global\n",
    "        frame_data['colors_rgb_global'] = colors_rgb_global\n",
    "        frame_data['colors_confidence_global'] = colors_confidence_global\n",
    "        frame_data['colors_rainbow_global'] = colors_rainbow_global\n",
    "\n",
    "        frame_data['sorted_pts3d_local'] = sorted_pts3d_local\n",
    "        frame_data['colors_rgb_local'] = colors_rgb_local\n",
    "        frame_data['colors_confidence_local'] = colors_confidence_local\n",
    "        frame_data['colors_rainbow_local'] = colors_rainbow_local\n",
    "\n",
    "        frame_data['max_conf_global'] = max_conf_global\n",
    "        frame_data['is_high_confidence'] = is_high_confidence\n",
    "\n",
    "        frame_data['c2w'] = c2w\n",
    "        frame_data['height'] = height\n",
    "        frame_data['width'] = width\n",
    "        frame_data['focal_length'] = focal_length\n",
    "        frame_data['img_downsampled'] = img_downsampled\n",
    "        frame_data['rainbow_color'] = rainbow_color_for_frame\n",
    "\n",
    "        frame_data_list.append(frame_data)\n",
    "\n",
    "    # Compute scene extent and max_extent\n",
    "    cumulative_pts_combined = np.concatenate(cumulative_pts, axis=0)\n",
    "    min_coords = np.min(cumulative_pts_combined, axis=0)\n",
    "    max_coords = np.max(cumulative_pts_combined, axis=0)\n",
    "    scene_extent = max_coords - min_coords\n",
    "    max_extent = np.max(scene_extent)\n",
    "\n",
    "    # Now create the visualization nodes\n",
    "    for i in tqdm(range(num_frames)):\n",
    "        frame_data = frame_data_list[i]\n",
    "\n",
    "        # Initialize frame node\n",
    "        frame_node = server.scene.add_frame(f\"/cams/t{i}\", show_axes=False)\n",
    "\n",
    "        # Initialize point cloud nodes\n",
    "        # Global point cloud\n",
    "        point_node_global = server.scene.add_point_cloud(\n",
    "            name=f\"/pts3d_global/t{i}\",\n",
    "            points=frame_data['sorted_pts3d_global'],\n",
    "            colors=frame_data['colors_rgb_global'],\n",
    "            point_size=gui_point_size.value,\n",
    "            point_shape=\"rounded\",\n",
    "            visible=False,  # Initially hidden\n",
    "        )\n",
    "\n",
    "        # Local point cloud\n",
    "        point_node_local = server.scene.add_point_cloud(\n",
    "            name=f\"/pts3d_local/t{i}\",\n",
    "            points=frame_data['sorted_pts3d_local'],\n",
    "            colors=frame_data['colors_rgb_local'],\n",
    "            point_size=gui_point_size.value,\n",
    "            point_shape=\"rounded\",\n",
    "            visible=True if frame_data_list[i]['is_high_confidence'] else False,\n",
    "        )\n",
    "\n",
    "        # Compute frustum parameters\n",
    "        c2w = frame_data['c2w']\n",
    "        rotation_matrix = c2w[:3, :3]\n",
    "        position = c2w[:3, 3]\n",
    "        rotation_quaternion = tf.SO3.from_matrix(rotation_matrix).wxyz\n",
    "\n",
    "        fov = 2 * np.arctan2(frame_data['height'] / 2, frame_data['focal_length'])\n",
    "        aspect_ratio = frame_data['width'] / frame_data['height']\n",
    "        frustum_scale = max_extent * (gui_frustum_size_percent.value / 100.0)\n",
    "\n",
    "        frustum_node = server.scene.add_camera_frustum(\n",
    "            name=f\"/cams/t{i}/frustum\",\n",
    "            fov=fov,\n",
    "            aspect=aspect_ratio,\n",
    "            scale=frustum_scale,\n",
    "            color=frame_data['rainbow_color'],\n",
    "            image=frame_data['img_downsampled'],\n",
    "            wxyz=rotation_quaternion,\n",
    "            position=position,\n",
    "            visible=True if frame_data_list[i]['is_high_confidence'] else False,\n",
    "        )\n",
    "\n",
    "        # Store nodes\n",
    "        frame_data['frame_node'] = frame_node\n",
    "        frame_data['point_node_global'] = point_node_global\n",
    "        frame_data['point_node_local'] = point_node_local\n",
    "        frame_data['frustum_node'] = frustum_node\n",
    "\n",
    "    # Set initial visibility\n",
    "    for frame_data in frame_data_list:\n",
    "        frame_data['frame_node'].visible = False\n",
    "        frame_data['point_node_global'].visible = False\n",
    "        frame_data['point_node_local'].visible = False\n",
    "        frame_data['frustum_node'].visible = False\n",
    "\n",
    "    def update_visibility():\n",
    "        current_timestep = int(gui_timestep.value)\n",
    "        with server.atomic():\n",
    "            for i in range(num_frames):\n",
    "                frame_data = frame_data_list[i]\n",
    "                if i <= current_timestep:\n",
    "                    is_high_confidence = frame_data['is_high_confidence']\n",
    "                    show_frame = False\n",
    "                    if is_high_confidence and gui_show_high_conf.value:\n",
    "                        show_frame = True\n",
    "                    if not is_high_confidence and gui_show_low_conf.value:\n",
    "                        show_frame = True\n",
    "\n",
    "                    # Update visibility based on global point cloud confidence\n",
    "                    frame_data['frame_node'].visible = show_frame\n",
    "                    frame_data['frustum_node'].visible = show_frame\n",
    "\n",
    "                    # Show/hide global point cloud\n",
    "                    frame_data['point_node_global'].visible = show_frame and gui_show_global.value\n",
    "\n",
    "                    # Show/hide local point cloud\n",
    "                    frame_data['point_node_local'].visible = show_frame and gui_show_local.value\n",
    "                else:\n",
    "                    frame_data['frame_node'].visible = False\n",
    "                    frame_data['frustum_node'].visible = False\n",
    "                    frame_data['point_node_global'].visible = False\n",
    "                    frame_data['point_node_local'].visible = False\n",
    "        server.flush()\n",
    "\n",
    "    @gui_timestep.on_update\n",
    "    def _(_) -> None:\n",
    "        update_visibility()\n",
    "\n",
    "    @gui_point_size.on_update\n",
    "    def _(_) -> None:\n",
    "        with server.atomic():\n",
    "            for frame_data in frame_data_list:\n",
    "                frame_data['point_node_global'].point_size = gui_point_size.value\n",
    "                frame_data['point_node_local'].point_size = gui_point_size.value\n",
    "        server.flush()\n",
    "\n",
    "    @gui_frustum_size_percent.on_update\n",
    "    def _(_) -> None:\n",
    "        frustum_scale = max_extent * (gui_frustum_size_percent.value / 100.0)\n",
    "        with server.atomic():\n",
    "            for frame_data in frame_data_list:\n",
    "                frame_data['frustum_node'].scale = frustum_scale\n",
    "        server.flush()\n",
    "\n",
    "    @gui_show_confidence.on_update\n",
    "    def _(_) -> None:\n",
    "        update_point_cloud_colors()\n",
    "\n",
    "    @gui_rainbow_color.on_update\n",
    "    def _(_) -> None:\n",
    "        update_point_cloud_colors()\n",
    "\n",
    "    @gui_show_global.on_update\n",
    "    def _(_) -> None:\n",
    "        update_visibility()\n",
    "\n",
    "    @gui_show_local.on_update\n",
    "    def _(_) -> None:\n",
    "        update_visibility()\n",
    "\n",
    "    def update_point_cloud_colors():\n",
    "        with server.atomic():\n",
    "            for frame_data in frame_data_list:\n",
    "                num_points_to_show_global = frame_data.get('num_points_to_show_global', len(frame_data['sorted_pts3d_global']))\n",
    "                num_points_to_show_local = frame_data.get('num_points_to_show_local', len(frame_data['sorted_pts3d_local']))\n",
    "\n",
    "                # Update global point cloud colors\n",
    "                if gui_show_confidence.value:\n",
    "                    colors_global = frame_data['colors_confidence_global'][:num_points_to_show_global]\n",
    "                elif gui_rainbow_color.value:\n",
    "                    colors_global = frame_data['colors_rainbow_global'][:num_points_to_show_global]\n",
    "                else:\n",
    "                    colors_global = frame_data['colors_rgb_global'][:num_points_to_show_global]\n",
    "                frame_data['point_node_global'].colors = colors_global\n",
    "\n",
    "                # Update local point cloud colors\n",
    "                if gui_show_confidence.value:\n",
    "                    colors_local = frame_data['colors_confidence_local'][:num_points_to_show_local]\n",
    "                elif gui_rainbow_color.value:\n",
    "                    colors_local = frame_data['colors_rainbow_local'][:num_points_to_show_local]\n",
    "                else:\n",
    "                    colors_local = frame_data['colors_rgb_local'][:num_points_to_show_local]\n",
    "                frame_data['point_node_local'].colors = colors_local\n",
    "        server.flush()\n",
    "\n",
    "    @gui_show_high_conf.on_update\n",
    "    def _(_) -> None:\n",
    "        update_visibility()\n",
    "\n",
    "    @gui_show_low_conf.on_update\n",
    "    def _(_) -> None:\n",
    "        update_visibility()\n",
    "\n",
    "    @gui_global_conf_threshold.on_update\n",
    "    def _(_) -> None:\n",
    "        # Update high-confidence flags based on new threshold\n",
    "        for frame_data in frame_data_list:\n",
    "            is_high_confidence = frame_data['max_conf_global'] >= gui_global_conf_threshold.value\n",
    "            frame_data['is_high_confidence'] = is_high_confidence\n",
    "        update_visibility()\n",
    "\n",
    "    @gui_min_conf_percentile.on_update\n",
    "    def _(_) -> None:\n",
    "        # Update number of points to display based on percentile\n",
    "        percentile = gui_min_conf_percentile.value\n",
    "        with server.atomic():\n",
    "            for frame_data in frame_data_list:\n",
    "                # For global point cloud\n",
    "                total_points_global = len(frame_data['sorted_pts3d_global'])\n",
    "                num_points_to_show_global = int(total_points_global * (100 - percentile) / 100)\n",
    "                num_points_to_show_global = max(1, num_points_to_show_global)  # Ensure at least one point\n",
    "                frame_data['num_points_to_show_global'] = num_points_to_show_global\n",
    "                frame_data['point_node_global'].points = frame_data['sorted_pts3d_global'][:num_points_to_show_global]\n",
    "\n",
    "                # For local point cloud\n",
    "                total_points_local = len(frame_data['sorted_pts3d_local'])\n",
    "                num_points_to_show_local = int(total_points_local * (100 - percentile) / 100)\n",
    "                num_points_to_show_local = max(1, num_points_to_show_local)  # Ensure at least one point\n",
    "                frame_data['num_points_to_show_local'] = num_points_to_show_local\n",
    "                frame_data['point_node_local'].points = frame_data['sorted_pts3d_local'][:num_points_to_show_local]\n",
    "\n",
    "            # Update colors\n",
    "            update_point_cloud_colors()\n",
    "        server.flush()\n",
    "\n",
    "    def playback_loop():\n",
    "        while True:\n",
    "            if gui_playing.value:\n",
    "                gui_timestep.value = (int(gui_timestep.value) + 1) % num_frames\n",
    "            time.sleep(1.0 / gui_framerate.value)\n",
    "\n",
    "    playback_thread = threading.Thread(target=playback_loop)\n",
    "    playback_thread.start()\n",
    "\n",
    "    @button_render_gif.on_click\n",
    "    def _(event: viser.GuiEvent) -> None:\n",
    "        client = event.client\n",
    "        if client is None:\n",
    "            print(\"Error: No client connected.\")\n",
    "            return\n",
    "        try:\n",
    "            images = []\n",
    "            original_timestep = gui_timestep.value\n",
    "            original_playing = gui_playing.value\n",
    "            gui_playing.value = False\n",
    "            fps = gui_framerate.value\n",
    "            for i in range(num_frames):\n",
    "                gui_timestep.value = i\n",
    "                time.sleep(0.1)\n",
    "                image = client.get_render(height=720, width=1280)\n",
    "                images.append(image)\n",
    "            gif_bytes = iio.imwrite(\"<bytes>\", images, extension=\".gif\", fps=fps, loop=0)\n",
    "            client.send_file_download(\"visualization.gif\", gif_bytes)\n",
    "            gui_timestep.value = original_timestep\n",
    "            gui_playing.value = original_playing\n",
    "        except Exception as e:\n",
    "            print(f\"Error while rendering GIF: {e}\")\n",
    "\n",
    "    print(f\"Visualization setup complete. Access the viser server at http://localhost:{port}\")\n",
    "    public_url = server.request_share_url()\n",
    "    print(f\"Public URL: {public_url}\")\n",
    "    return server\n",
    "\n",
    "# Start the visualization server\n",
    "server = start_visualization(\n",
    "    output=output,\n",
    "    min_conf_thr_percentile=10,\n",
    "    global_conf_thr_value_to_drop_view=1.5,\n",
    "    port=8020\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "server.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the images to jpgs\n",
    "for i, img in enumerate(output['views']):\n",
    "    img = img['img'][0]\n",
    "    img = img.permute(1, 2, 0).cpu().numpy()\n",
    "    img = ((img + 1) * 127.5).astype(np.uint8)\n",
    "    img = Image.fromarray(img)\n",
    "    img.save(f\"./output/img_{i}.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from fast3r.dust3r.utils.device import to_numpy\n",
    "import pyrender\n",
    "\n",
    "# Set the EGL platform for offscreen rendering\n",
    "os.environ[\"PYOPENGL_PLATFORM\"] = \"egl\"\n",
    "\n",
    "def create_camera_pose(camera_position, target_point, up_vector):\n",
    "    \"\"\"\n",
    "    Create a camera pose matrix (camera-to-world) that positions the camera at camera_position\n",
    "    and orients it to look at target_point.\n",
    "    \"\"\"\n",
    "    # Compute forward vector (from camera to target)\n",
    "    forward_vector = target_point - camera_position\n",
    "    forward_vector /= np.linalg.norm(forward_vector)\n",
    "\n",
    "    # Compute right and up vectors\n",
    "    right_vector = np.cross(up_vector, forward_vector)\n",
    "    if np.linalg.norm(right_vector) < 1e-6:\n",
    "        # Adjust up_vector if it's parallel to forward_vector\n",
    "        up_vector = np.array([0, 0, 1]) if up_vector[1] != 1 else np.array([1, 0, 0])\n",
    "        right_vector = np.cross(up_vector, forward_vector)\n",
    "    right_vector /= np.linalg.norm(right_vector)\n",
    "    up_vector = np.cross(forward_vector, right_vector)\n",
    "\n",
    "    # Construct the camera-to-world matrix\n",
    "    camera_pose = np.eye(4)\n",
    "    camera_pose[:3, 0] = right_vector\n",
    "    camera_pose[:3, 1] = up_vector\n",
    "    camera_pose[:3, 2] = forward_vector\n",
    "    camera_pose[:3, 3] = camera_position\n",
    "    return camera_pose\n",
    "\n",
    "def convert_c2w_to_opengl_view(c2w):\n",
    "    \"\"\"\n",
    "    Convert a camera-to-world (c2w) extrinsic matrix to an OpenGL-compatible view matrix.\n",
    "    \"\"\"\n",
    "    # Invert the camera-to-world matrix to get world-to-camera (view) matrix\n",
    "    world_to_camera = np.linalg.inv(c2w)\n",
    "\n",
    "    # OpenGL requires flipping the Y and Z axes\n",
    "    opengl_to_camera = np.array([\n",
    "        [1,  0,  0, 0],\n",
    "        [0, -1,  0, 0],\n",
    "        [0,  0, -1, 0],\n",
    "        [0,  0,  0, 1]\n",
    "    ])\n",
    "\n",
    "    # Compute the OpenGL view matrix\n",
    "    opengl_view_matrix = world_to_camera @ opengl_to_camera\n",
    "    return opengl_view_matrix\n",
    "\n",
    "def render_cumulative_pts3d_viz(preds, views, output_dir='./output', point_size=5.0, min_conf_thr_percentile=0):\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    cumulative_pts = []\n",
    "    cumulative_colors = []\n",
    "\n",
    "    # First, accumulate all points across all frames to compute the scene extents\n",
    "    for i, pred in enumerate(preds):\n",
    "        # Flatten `pts3d` and `img_rgb`\n",
    "        pts3d = to_numpy(pred['pts3d_in_other_view'].cpu().squeeze()).reshape(-1, 3)\n",
    "        img_rgb = to_numpy(views[i]['img'].cpu().squeeze().permute(1, 2, 0)).reshape(-1, 3)\n",
    "\n",
    "        # Apply confidence threshold\n",
    "        conf = to_numpy(pred['conf'].cpu().squeeze()).flatten()\n",
    "        conf_thr = np.percentile(conf, min_conf_thr_percentile)\n",
    "        mask = conf > conf_thr\n",
    "\n",
    "        # Apply the mask to points and colors\n",
    "        pts3d_masked = pts3d[mask]\n",
    "        colors_masked = ((img_rgb[mask] + 1) * 127.5).astype(np.uint8)\n",
    "\n",
    "        # Accumulate masked points and colors\n",
    "        cumulative_pts.append(pts3d_masked)\n",
    "        cumulative_colors.append(colors_masked)\n",
    "\n",
    "    # Combine cumulative points and colors\n",
    "    cumulative_pts_combined = np.concatenate(cumulative_pts, axis=0)\n",
    "    cumulative_colors_combined = np.concatenate(cumulative_colors, axis=0)\n",
    "\n",
    "    # Verify that we have valid points\n",
    "    if cumulative_pts_combined.shape[0] == 0:\n",
    "        print(\"No points to render. Exiting.\")\n",
    "        return\n",
    "\n",
    "    # Compute the center and extents of the cumulative point cloud\n",
    "    point_cloud_center = np.mean(cumulative_pts_combined, axis=0)\n",
    "    min_coords = np.min(cumulative_pts_combined, axis=0)\n",
    "    max_coords = np.max(cumulative_pts_combined, axis=0)\n",
    "    scene_extent = max_coords - min_coords\n",
    "    max_extent = np.max(scene_extent)\n",
    "\n",
    "    # Debug: Print point cloud stats\n",
    "    print(f\"Point cloud center: {point_cloud_center}\")\n",
    "    print(f\"Scene extents: {scene_extent}\")\n",
    "    print(f\"Max extent: {max_extent}\")\n",
    "\n",
    "    # Adjust camera position based on coordinate system\n",
    "    # Assuming Z-up coordinate system (adjust if necessary)\n",
    "    camera_distance = max_extent * 2  # Adjust multiplier as needed\n",
    "    camera_position = point_cloud_center + np.array([0, 0, camera_distance])  # Camera above the scene\n",
    "    up_vector = np.array([0, 1, 0])  # Y-axis is up in this case\n",
    "\n",
    "    # Create the camera pose looking at the center of the point cloud\n",
    "    camera_pose = create_camera_pose(camera_position, point_cloud_center, up_vector=up_vector)\n",
    "\n",
    "    # Convert to OpenGL view matrix\n",
    "    opengl_camera_pose = convert_c2w_to_opengl_view(camera_pose)\n",
    "\n",
    "    print(\"Using stationary bird's eye view camera pose.\")\n",
    "\n",
    "    for i in range(len(preds)):\n",
    "        print(f\"Rendering frame {i}...\")\n",
    "\n",
    "        # For each frame, use the cumulative points up to that frame\n",
    "        cumulative_pts_upto_frame = np.concatenate(cumulative_pts[:i+1], axis=0)\n",
    "        cumulative_colors_upto_frame = np.concatenate(cumulative_colors[:i+1], axis=0)\n",
    "\n",
    "        # Create the Pyrender scene and render\n",
    "        pyrender_scene = pyrender.Scene()\n",
    "        points_mesh = pyrender.Mesh.from_points(cumulative_pts_upto_frame, colors=cumulative_colors_upto_frame)\n",
    "        pyrender_scene.add(points_mesh)\n",
    "\n",
    "        # Set up the camera\n",
    "        camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=16/9)\n",
    "        pyrender_scene.add(camera, pose=opengl_camera_pose)\n",
    "\n",
    "        # Add light\n",
    "        light = pyrender.DirectionalLight(color=np.ones(3), intensity=3.0)\n",
    "        pyrender_scene.add(light, pose=opengl_camera_pose)\n",
    "\n",
    "        # Render the scene\n",
    "        r = pyrender.OffscreenRenderer(viewport_width=1920, viewport_height=1080, point_size=point_size)\n",
    "        color, _ = r.render(pyrender_scene)\n",
    "        r.delete()\n",
    "\n",
    "        # Save the rendered image\n",
    "        frame_filename = os.path.join(output_dir, f'cumulative_{i:03d}.png')\n",
    "        print(f\"Frame {i} saved as {frame_filename}\")\n",
    "        Image.fromarray(color).save(frame_filename)\n",
    "\n",
    "    print(\"Rendering complete. Frames saved as PNG files.\")\n",
    "\n",
    "# Run the rendering function\n",
    "render_cumulative_pts3d_viz(\n",
    "    output['preds'],\n",
    "    output['views'],\n",
    "    output_dir='./output',\n",
    "    point_size=5.0,\n",
    "    min_conf_thr_percentile=0\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import trimesh\n",
    "import pyrender\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Set PyOpenGL platform to EGL for headless rendering\n",
    "os.environ[\"PYOPENGL_PLATFORM\"] = \"egl\"\n",
    "\n",
    "# Load the FUZE bottle trimesh and put it in a scene\n",
    "fuze_trimesh = trimesh.load('output/fuze.obj')\n",
    "mesh = pyrender.Mesh.from_trimesh(fuze_trimesh)\n",
    "scene = pyrender.Scene()\n",
    "scene.add(mesh)\n",
    "\n",
    "# Set up the camera -- z-axis away from the scene, x-axis right, y-axis up\n",
    "camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)\n",
    "s = np.sqrt(2) / 2\n",
    "camera_pose = np.array([\n",
    "    [0.0, -s,   s,   0.3],\n",
    "    [1.0,  0.0, 0.0, 0.0],\n",
    "    [0.0,  s,   s,   0.35],\n",
    "    [0.0,  0.0, 0.0, 1.0],\n",
    "])\n",
    "scene.add(camera, pose=camera_pose)\n",
    "\n",
    "# Set up the light -- a single spot light in the same spot as the camera\n",
    "light = pyrender.SpotLight(color=np.ones(3), intensity=3.0,\n",
    "                           innerConeAngle=np.pi / 16.0, outerConeAngle=np.pi / 6.0)\n",
    "scene.add(light, pose=camera_pose)\n",
    "\n",
    "# Initialize the offscreen renderer\n",
    "r = pyrender.OffscreenRenderer(640, 480)\n",
    "\n",
    "# Render the scene\n",
    "color, depth = r.render(scene)\n",
    "\n",
    "# Display the images\n",
    "plt.figure()\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.axis('off')\n",
    "plt.imshow(color)\n",
    "plt.title(\"Color\")\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.axis('off')\n",
    "plt.imshow(depth, cmap=plt.cm.gray_r)\n",
    "plt.title(\"Depth\")\n",
    "\n",
    "plt.show()\n",
    "\n",
    "# Clean up the renderer\n",
    "r.delete()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Plot the RGB images\n",
    "plot_rgb_images(output['views'])\n",
    "\n",
    "# Plot the confidence maps\n",
    "plot_confidence_maps(output['preds'])\n",
    "\n",
    "# Plot the 3D points\n",
    "plot_3d_points_with_colors(output['preds'], output['views'], flip_axes=True, as_mesh=False, min_conf_thr_percentile=30, export_ply_path='./output/combined_mesh.ply')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "from scipy.linalg import rq\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "def estimate_camera_matrix(world_points, image_points):\n",
    "    \"\"\"\n",
    "    Estimate the camera matrix from 3D world points and 2D image points using DLT.\n",
    "    \n",
    "    Parameters:\n",
    "    world_points (np.ndarray): Array of 3D points in the world coordinates, shape (N, 3).\n",
    "    image_points (np.ndarray): Array of 2D points in the image coordinates, shape (N, 2).\n",
    "    \n",
    "    Returns:\n",
    "    np.ndarray: The 3x4 camera matrix.\n",
    "    \"\"\"\n",
    "    assert world_points.shape[0] == image_points.shape[0], \"Number of points must match\"\n",
    "    num_points = world_points.shape[0]\n",
    "    \n",
    "    # Add homogeneous coordinates to the world points\n",
    "    homogeneous_world_points = np.hstack((world_points, np.ones((num_points, 1))))\n",
    "    \n",
    "    A = []\n",
    "    \n",
    "    for i in range(num_points):\n",
    "        X, Y, Z, _ = homogeneous_world_points[i]\n",
    "        u, v = image_points[i]\n",
    "        \n",
    "        # Two rows of the equation for each point\n",
    "        A.append([X, Y, Z, 1, 0, 0, 0, 0, -u*X, -u*Y, -u*Z, -u])\n",
    "        A.append([0, 0, 0, 0, X, Y, Z, 1, -v*X, -v*Y, -v*Z, -v])\n",
    "    \n",
    "    # Convert A to a numpy array\n",
    "    A = np.array(A)\n",
    "    \n",
    "    # Solve using SVD (Singular Value Decomposition)\n",
    "    U, S, Vt = np.linalg.svd(A)\n",
    "    \n",
    "    # The last row of Vt (or last column of V) is the solution\n",
    "    P = Vt[-1].reshape(3, 4)\n",
    "    \n",
    "    return P\n",
    "\n",
    "def decompose_camera_matrix(P):\n",
    "    \"\"\"\n",
    "    Decompose the camera matrix into intrinsic and extrinsic matrices.\n",
    "    \n",
    "    Parameters:\n",
    "    P (np.ndarray): The 3x4 camera matrix.\n",
    "    \n",
    "    Returns:\n",
    "    K (np.ndarray): The 3x3 intrinsic matrix.\n",
    "    R (np.ndarray): The 3x3 rotation matrix.\n",
    "    t (np.ndarray): The 3x1 translation vector.\n",
    "    \"\"\"\n",
    "    # Extract the camera matrix K and rotation matrix R using RQ decomposition\n",
    "    M = P[:, :3]  # The first 3x3 part of P\n",
    "    \n",
    "    # RQ Decomposition of M\n",
    "    K, R = rq(M)\n",
    "    \n",
    "    # Normalize K so that K[2,2] = 1\n",
    "    K /= K[2, 2]\n",
    "    \n",
    "    # Compute translation vector\n",
    "    t = np.dot(np.linalg.inv(K), P[:, 3])\n",
    "    \n",
    "    return K, R, t\n",
    "\n",
    "def plot_camera_cones(fig, R, t, K, color='blue', scale=0.1):\n",
    "    \"\"\"\n",
    "    Plot the camera as a cone in 3D space based on the intrinsic matrix K for focal length.\n",
    "    \n",
    "    Parameters:\n",
    "    fig (plotly.graph_objects.Figure): The existing Plotly figure.\n",
    "    R (np.ndarray): The 3x3 rotation matrix.\n",
    "    t (np.ndarray): The 3x1 translation vector.\n",
    "    K (np.ndarray): The 3x3 intrinsic matrix.\n",
    "    color (str): Color of the camera cone.\n",
    "    scale (float): Scale factor for the size of the cone base.\n",
    "    \"\"\"\n",
    "    # The focal length is the element K[0, 0] (assuming fx and fy are equal)\n",
    "    focal_length = K[0, 0] / K[2, 2]\n",
    "\n",
    "    # The camera center (apex of the cone)\n",
    "    camera_center = -R.T @ t\n",
    "\n",
    "    # Define the orientation of the cone based on the rotation matrix\n",
    "    direction = R.T @ np.array([0, 0, 1])  # Camera looks along the +Z axis in camera space\n",
    "\n",
    "    # Scale the direction by the focal length\n",
    "    direction = direction * focal_length\n",
    "\n",
    "    # Plot the camera cone\n",
    "    fig.add_trace(go.Cone(\n",
    "        x=[camera_center[0]],\n",
    "        y=[camera_center[1]],\n",
    "        z=[camera_center[2]],\n",
    "        u=[direction[0]],\n",
    "        v=[direction[1]],\n",
    "        w=[direction[2]],\n",
    "        colorscale=[[0, color], [1, color]],  # Single color for the cone\n",
    "        showscale=False,\n",
    "        sizemode=\"absolute\",\n",
    "        sizeref=scale,  # The size of the cone base\n",
    "        anchor=\"tip\",  # The tip of the cone is the camera center\n",
    "        name=\"Camera Cone\"\n",
    "    ))\n",
    "\n",
    "def plot_3d_points_with_estimated_camera(output, fig, camera_poses, min_conf_thr_percentile=80):\n",
    "    \"\"\"\n",
    "    Plot 3D points together with estimated camera cones in the same plot.\n",
    "    \n",
    "    Parameters:\n",
    "    output (dict): The output containing 'preds' with 3D points and corresponding 2D image points.\n",
    "    fig (plotly.graph_objects.Figure): The existing 3D plot.\n",
    "    camera_poses (list): List of estimated camera poses.\n",
    "    min_conf_thr_percentile (int): Percentile threshold for confidence values to filter points.\n",
    "    \"\"\"\n",
    "    # Plot the 3D points first\n",
    "    all_points = []\n",
    "    all_colors = []\n",
    "\n",
    "    for i, pred in enumerate(output['preds']):\n",
    "        pts3d = pred['pts3d_in_other_view'].cpu().numpy().squeeze()  # 3D points\n",
    "        img_rgb = output['views'][i]['img'].cpu().numpy().squeeze().transpose(1, 2, 0)  # RGB image (224x224)\n",
    "        conf = pred['conf'].cpu().numpy().squeeze()  # Confidence map\n",
    "\n",
    "        # Apply confidence threshold\n",
    "        conf_thr = np.percentile(conf, min_conf_thr_percentile)\n",
    "        mask = conf > conf_thr\n",
    "\n",
    "        # Rescale RGB values from [-1, 1] to [0, 255]\n",
    "        img_rgb = ((img_rgb + 1) * 127.5).astype(np.uint8).clip(0, 255)\n",
    "\n",
    "        # Flatten the points and colors, and apply mask\n",
    "        x, y, z = pts3d[..., 0].flatten(), pts3d[..., 1].flatten(), pts3d[..., 2].flatten()\n",
    "        r, g, b = img_rgb[..., 0].flatten(), img_rgb[..., 1].flatten(), img_rgb[..., 2].flatten()\n",
    "        x, y, z = x[mask.flatten()], y[mask.flatten()], z[mask.flatten()]\n",
    "        r, g, b = r[mask.flatten()], g[mask.flatten()], b[mask.flatten()]\n",
    "\n",
    "        colors = ['rgb({}, {}, {})'.format(r[j], g[j], b[j]) for j in range(len(r))]\n",
    "\n",
    "        # Add points to the plot\n",
    "        fig.add_trace(go.Scatter3d(\n",
    "            x=x, y=y, z=z,\n",
    "            mode='markers',\n",
    "            marker=dict(size=2, opacity=0.8, color=colors),\n",
    "            name=f\"View {i} Points\"\n",
    "        ))\n",
    "\n",
    "    # Now, plot the estimated cameras as cones\n",
    "    for i, (R, t, K) in enumerate(camera_poses):\n",
    "        plot_camera_cones(fig, R, t, K, color='blue')\n",
    "\n",
    "    fig.update_layout(\n",
    "        scene=dict(\n",
    "            xaxis_title='X',\n",
    "            yaxis_title='Y',\n",
    "            zaxis_title='Z'\n",
    "        ),\n",
    "        margin=dict(l=0, r=0, b=0, t=40)\n",
    "    )\n",
    "\n",
    "def estimate_camera_poses(output, min_conf_thr_percentile=80):\n",
    "    \"\"\"\n",
    "    Estimate camera poses from 3D points and 2D image points.\n",
    "    \n",
    "    Parameters:\n",
    "    output (dict): The output containing 'preds' with 3D points and corresponding 2D image points.\n",
    "    min_conf_thr_percentile (int): Percentile threshold for confidence values to filter points.\n",
    "    \n",
    "    Returns:\n",
    "    list: A list of camera poses (R, t, K) where R is rotation, t is translation, and K is intrinsic matrix.\n",
    "    \"\"\"\n",
    "    camera_poses = []\n",
    "    \n",
    "    # Loop through all views in output['preds']\n",
    "    for i, pred in enumerate(output['preds']):\n",
    "        # Get the 3D points and confidence map for the current view\n",
    "        world_points = pred['pts3d_in_other_view'].cpu().numpy().squeeze()  # Shape: (272, 512, 3)\n",
    "        conf = pred['conf'].cpu().numpy().squeeze()  # Confidence map\n",
    "\n",
    "        # Determine the confidence threshold based on the percentile\n",
    "        conf_thr = np.percentile(conf, min_conf_thr_percentile)\n",
    "\n",
    "        # Apply confidence mask to filter points\n",
    "        mask = conf > conf_thr\n",
    "        world_points_filtered = world_points[mask]\n",
    "\n",
    "        # Generate 2D pixel coordinates corresponding to the filtered points\n",
    "        h, w, _ = world_points.shape\n",
    "        image_points = np.indices((h, w)).reshape(2, -1).T  # Shape: (N, 2)\n",
    "        image_points_filtered = image_points[mask.flatten()]  # Apply mask to 2D points\n",
    "\n",
    "        if world_points_filtered.shape[0] == 0:\n",
    "            print(f\"View {i}: No points above confidence threshold. Skipping camera estimation.\")\n",
    "            continue\n",
    "\n",
    "        # Estimate the camera matrix\n",
    "        P = estimate_camera_matrix(world_points_filtered, image_points_filtered)\n",
    "        print(f\"Camera matrix for view {i}:\\n\", P)\n",
    "\n",
    "        # Decompose into intrinsic and extrinsic matrices\n",
    "        K, R, t = decompose_camera_matrix(P)\n",
    "        print(f\"Intrinsic matrix (K) for view {i}:\\n\", K)\n",
    "        print(f\"Rotation matrix (R) for view {i}:\\n\", R)\n",
    "        print(f\"Translation vector (t) for view {i}:\\n\", t)\n",
    "\n",
    "        # Store the camera pose (rotation, translation, and intrinsic matrix)\n",
    "        camera_poses.append((R, t, K))\n",
    "    \n",
    "    return camera_poses\n",
    "\n",
    "# Estimate the camera poses first\n",
    "camera_poses = estimate_camera_poses(output, min_conf_thr_percentile=80)\n",
    "\n",
    "# Create a 3D plot and plot the 3D points together with the estimated cameras\n",
    "fig = go.Figure()\n",
    "plot_3d_points_with_estimated_camera(output, fig, camera_poses, min_conf_thr_percentile=50)\n",
    "\n",
    "# Display the final plot with 3D points and camera cones\n",
    "fig.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output['views'][0]['img'].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Align with DTU point cloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The Rt matrix of the first image lives at /path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18/pos_001.txt\n",
    "# it looks like this:\n",
    "# 2607.429996 -3.844898 1498.178098 -533936.661373\n",
    "# -192.076910 2862.552532 681.798177 23434.686572\n",
    "# -0.241605 -0.030951 0.969881 22.540121\n",
    "# I'd like to use this this to rotate an input 3D points to the correct orientation\n",
    "# my 3d points assumes the camera is at (0, 0, 0) and looking at (0, 0, 1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import trimesh\n",
    "import plotly.graph_objs as go\n",
    "from scipy.linalg import rq\n",
    "\n",
    "def load_camera_matrix(filepath):\n",
    "    \"\"\"Loads the camera calibration matrix from the given file.\"\"\"\n",
    "    with open(filepath, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    camera_matrix = np.array([list(map(float, line.split())) for line in lines])\n",
    "    return camera_matrix\n",
    "\n",
    "def decompose_camera_matrix(camera_matrix):\n",
    "    \"\"\"Decomposes the camera calibration matrix into intrinsic matrix (K), rotation matrix (R), and translation vector (t).\"\"\"\n",
    "    # The camera matrix is 3x4\n",
    "    M = camera_matrix[:, :3]\n",
    "    \n",
    "    # RQ decomposition to separate K and R\n",
    "    K, R = rq(M)\n",
    "    \n",
    "    # Normalize K to ensure the sign of the diagonal is positive\n",
    "    T = np.diag(np.sign(np.diag(K)))\n",
    "    K = K @ T\n",
    "    R = T @ R\n",
    "    \n",
    "    # Compute translation vector t\n",
    "    t = np.linalg.inv(K) @ camera_matrix[:, 3]\n",
    "    \n",
    "    # Camera position C = -R^T * t\n",
    "    camera_position = -R.T @ t\n",
    "    \n",
    "    return K, R, t, camera_position\n",
    "\n",
    "def apply_transformation_to_point_cloud(ply_filepath, camera_matrix_filepath):\n",
    "    \"\"\"Applies the rotation and translation from the decomposed camera matrix to a point cloud loaded from a .ply file.\"\"\"\n",
    "    \n",
    "    # Load the point cloud\n",
    "    point_cloud = trimesh.load(ply_filepath)\n",
    "    \n",
    "    # Load and decompose the camera matrix\n",
    "    camera_matrix = load_camera_matrix(camera_matrix_filepath)\n",
    "    K, R, t, camera_position = decompose_camera_matrix(camera_matrix)\n",
    "\n",
    "    \n",
    "    # print point cloud range before transformation\n",
    "    print(f\"X range: {np.min(point_cloud.vertices[:, 0])} - {np.max(point_cloud.vertices[:, 0])} = {np.max(point_cloud.vertices[:, 0]) - np.min(point_cloud.vertices[:, 0])}\")\n",
    "    print(f\"Y range: {np.min(point_cloud.vertices[:, 1])} - {np.max(point_cloud.vertices[:, 1])} = {np.max(point_cloud.vertices[:, 1]) - np.min(point_cloud.vertices[:, 1])}\")\n",
    "    print(f\"Z range: {np.min(point_cloud.vertices[:, 2])} - {np.max(point_cloud.vertices[:, 2])} = {np.max(point_cloud.vertices[:, 2]) - np.min(point_cloud.vertices[:, 2])}\")\n",
    "\n",
    "    # prting the camera position\n",
    "    print(f\"Camera position: {camera_position}\")\n",
    "    \n",
    "    # Apply the rotation matrix to the point cloud vertices\n",
    "    rotated_points = (R @ point_cloud.vertices.T).T\n",
    "    \n",
    "    # Apply translation\n",
    "    transformed_points = rotated_points + t\n",
    "    \n",
    "    # Print the range of the transformed points per axis\n",
    "    print(f\"X range: {np.min(transformed_points[:, 0])} - {np.max(transformed_points[:, 0])} = {np.max(transformed_points[:, 0]) - np.min(transformed_points[:, 0])}\")\n",
    "    print(f\"Y range: {np.min(transformed_points[:, 1])} - {np.max(transformed_points[:, 1])} = {np.max(transformed_points[:, 1]) - np.min(transformed_points[:, 1])}\")\n",
    "    print(f\"Z range: {np.min(transformed_points[:, 2])} - {np.max(transformed_points[:, 2])} = {np.max(transformed_points[:, 2]) - np.min(transformed_points[:, 2])}\")\n",
    "    \n",
    "    # Create a new point cloud with rotated and translated points\n",
    "    transformed_point_cloud = trimesh.PointCloud(vertices=transformed_points, colors=point_cloud.colors)\n",
    "    \n",
    "    return transformed_point_cloud\n",
    "\n",
    "def plot_point_cloud(point_cloud, title=\"Transformed Point Cloud\"):\n",
    "    \"\"\"Visualizes a point cloud using Plotly.\"\"\"\n",
    "    x = point_cloud.vertices[:, 0]\n",
    "    y = point_cloud.vertices[:, 1]\n",
    "    z = point_cloud.vertices[:, 2]\n",
    "    colors = point_cloud.colors / 255.0  # Normalize colors to [0, 1] for Plotly\n",
    "    \n",
    "    fig = go.Figure(data=[go.Scatter3d(\n",
    "        x=x, y=y, z=z,\n",
    "        mode='markers',\n",
    "        marker=dict(\n",
    "            size=2,\n",
    "            color=colors,\n",
    "            opacity=0.8\n",
    "        )\n",
    "    )])\n",
    "    \n",
    "    fig.update_layout(\n",
    "        title=title,\n",
    "        scene=dict(\n",
    "            xaxis_title='X',\n",
    "            yaxis_title='Y',\n",
    "            zaxis_title='Z'\n",
    "        ),\n",
    "        margin=dict(l=0, r=0, b=0, t=40),\n",
    "        height=800\n",
    "    )\n",
    "    \n",
    "    fig.show()\n",
    "\n",
    "# Example usage:\n",
    "ply_filepath = '/path/to/combined_mesh.ply'\n",
    "camera_matrix_filepath = '/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18/pos_001.txt'\n",
    "\n",
    "transformed_point_cloud = apply_transformation_to_point_cloud(ply_filepath, camera_matrix_filepath)\n",
    "\n",
    "# Save the transformed point cloud to a new .ply file\n",
    "transformed_point_cloud.export('/path/to/transformed_output.ply')\n",
    "\n",
    "# Visualize the transformed point cloud\n",
    "plot_point_cloud(transformed_point_cloud)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import trimesh\n",
    "\n",
    "def load_and_print_xyz_ranges(ply_filepath):\n",
    "    \"\"\"Print the per-axis extent of the geometry stored in a .ply file.\n",
    "\n",
    "    The file is loaded with `trimesh.load`; for vertex columns 0, 1 and 2\n",
    "    (X, Y, Z) the minimum, the maximum and their difference are printed.\n",
    "    Nothing is returned.\n",
    "    \"\"\"\n",
    "    vertices = trimesh.load(ply_filepath).vertices\n",
    "    \n",
    "    # Column-wise reductions: one (min, max) pair per coordinate axis.\n",
    "    x_min, y_min, z_min = np.min(vertices, axis=0)\n",
    "    x_max, y_max, z_max = np.max(vertices, axis=0)\n",
    "    \n",
    "    print(f\"X range: {x_min} - {x_max} = {x_max - x_min}\")\n",
    "    print(f\"Y range: {y_min} - {y_max} = {y_max - y_min}\")\n",
    "    print(f\"Z range: {z_min} - {z_max} = {z_max - z_min}\")\n",
    "\n",
    "# Example usage:\n",
    "# NOTE: placeholder path -- point it at the DTU reference scan (.ply) to inspect.\n",
    "reference_ply_filepath = '/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Points/stl/stl006_total.ply'\n",
    "\n",
    "load_and_print_xyz_ranges(reference_ply_filepath)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import plotly.graph_objs as go\n",
    "import os\n",
    "from scipy.linalg import rq\n",
    "\n",
    "def load_camera_matrix(filepath):\n",
    "    \"\"\"Load a whitespace-separated camera matrix from a text file.\n",
    "\n",
    "    Args:\n",
    "        filepath: text file with one matrix row per line (e.g. a DTU\n",
    "            `pos_XXX.txt`; downstream code expects a 3x4 projection matrix).\n",
    "\n",
    "    Returns:\n",
    "        2-D float ndarray with one row per non-empty line.\n",
    "\n",
    "    `np.loadtxt` skips blank lines (e.g. a trailing empty line), which would\n",
    "    otherwise produce an empty row and a ragged, unconvertible row list.\n",
    "    \"\"\"\n",
    "    return np.loadtxt(filepath, dtype=float, ndmin=2)\n",
    "\n",
    "def decompose_camera_matrix(camera_matrix):\n",
    "    \"\"\"Decompose a 3x4 projection matrix P = K [R | t].\n",
    "\n",
    "    Args:\n",
    "        camera_matrix: 3x4 array. A projection matrix is only defined up to a\n",
    "            non-zero scale factor, which may be negative.\n",
    "\n",
    "    Returns:\n",
    "        K: 3x3 upper-triangular intrinsics with positive diagonal, scaled so\n",
    "           that K[2, 2] == 1.\n",
    "        R: 3x3 rotation matrix with det(R) == +1.\n",
    "        t: translation vector, shape (3,).\n",
    "        camera_position: camera centre in world coordinates, C = -R^T t.\n",
    "    \"\"\"\n",
    "    camera_matrix = np.asarray(camera_matrix, dtype=float)\n",
    "    # Left 3x3 block M = K R.\n",
    "    M = camera_matrix[:, :3]\n",
    "    \n",
    "    # RQ decomposition: M = K R with K upper-triangular and R orthogonal.\n",
    "    K, R = rq(M)\n",
    "    \n",
    "    # Resolve the RQ sign ambiguity so K has a positive diagonal; T @ T == I,\n",
    "    # so the product K @ R is unchanged.\n",
    "    T = np.diag(np.sign(np.diag(K)))\n",
    "    K = K @ T\n",
    "    R = T @ R\n",
    "    \n",
    "    # t = K^-1 p4, solved directly instead of forming the inverse.\n",
    "    t = np.linalg.solve(K, camera_matrix[:, 3])\n",
    "    \n",
    "    # A negative overall scale of P leaves R as a reflection (det -1).\n",
    "    # Negating both R and t describes the same camera with a proper rotation.\n",
    "    if np.linalg.det(R) < 0:\n",
    "        R = -R\n",
    "        t = -t\n",
    "    \n",
    "    # Remove the remaining positive scale so K[2, 2] == 1 (R and t unaffected).\n",
    "    K = K / K[2, 2]\n",
    "    \n",
    "    # Camera position C = -R^T * t\n",
    "    camera_position = -R.T @ t\n",
    "    \n",
    "    return K, R, t, camera_position\n",
    "\n",
    "def plot_camera_poses(base_path, pose_count):\n",
    "    \"\"\"Load `pos_XXX.txt` projection matrices and plot the camera poses.\n",
    "\n",
    "    For i in 1..pose_count, `pos_{i:03d}.txt` under `base_path` is loaded and\n",
    "    printed together with its K/R/t decomposition. Camera centres are drawn\n",
    "    as markers and viewing directions as cones starting at each centre.\n",
    "\n",
    "    Args:\n",
    "        base_path: directory containing the `pos_XXX.txt` files.\n",
    "        pose_count: number of consecutive pose files to read, starting at 001.\n",
    "    \"\"\"\n",
    "    camera_positions = []\n",
    "    camera_orientations = []\n",
    "    \n",
    "    for i in range(1, pose_count + 1):\n",
    "        filepath = os.path.join(base_path, f'pos_{i:03d}.txt')\n",
    "        camera_matrix = load_camera_matrix(filepath)\n",
    "        \n",
    "        # Print the full camera matrix\n",
    "        print(f\"Camera Matrix {i}:\\n{camera_matrix}\\n\")\n",
    "        \n",
    "        K, R, t, camera_position = decompose_camera_matrix(camera_matrix)\n",
    "        \n",
    "        # Print the decomposed matrices\n",
    "        print(f\"Intrinsic Matrix (K) {i}:\\n{K}\\n\")\n",
    "        print(f\"Rotation Matrix (R) {i}:\\n{R}\\n\")\n",
    "        print(f\"Translation Vector (t) {i}:\\n{t}\\n\")\n",
    "        print(f\"Camera Position {i}: {camera_position}\\n\")\n",
    "        \n",
    "        # Viewing direction = principal axis det(M) * m3 (Hartley & Zisserman),\n",
    "        # with M the left 3x3 block of P and m3 its last row. This points\n",
    "        # towards the scene for any sign/scale of P, instead of assuming a\n",
    "        # fixed -Z (OpenGL) or +Z (OpenCV) camera axis.\n",
    "        M = np.asarray(camera_matrix, dtype=float)[:, :3]\n",
    "        camera_direction = np.linalg.det(M) * M[2]\n",
    "        direction_norm = np.linalg.norm(camera_direction)\n",
    "        if direction_norm > 0:\n",
    "            camera_direction = camera_direction / direction_norm\n",
    "        \n",
    "        camera_positions.append(camera_position)\n",
    "        camera_orientations.append(camera_direction)\n",
    "    \n",
    "    # Convert lists to numpy arrays\n",
    "    camera_positions = np.array(camera_positions)\n",
    "    camera_orientations = np.array(camera_orientations)\n",
    "    \n",
    "    # Create the 3D scatter plot for camera positions\n",
    "    scatter = go.Scatter3d(\n",
    "        x=camera_positions[:, 0],\n",
    "        y=camera_positions[:, 1],\n",
    "        z=camera_positions[:, 2],\n",
    "        mode='markers',\n",
    "        marker=dict(size=5, color='blue'),\n",
    "        name='Camera Positions'\n",
    "    )\n",
    "    \n",
    "    # Cones for camera orientations; 'tail' anchors each cone at its camera centre.\n",
    "    quiver = go.Cone(\n",
    "        x=camera_positions[:, 0],\n",
    "        y=camera_positions[:, 1],\n",
    "        z=camera_positions[:, 2],\n",
    "        u=camera_orientations[:, 0],\n",
    "        v=camera_orientations[:, 1],\n",
    "        w=camera_orientations[:, 2],\n",
    "        sizemode='scaled',\n",
    "        sizeref=2,\n",
    "        anchor='tail',\n",
    "        colorscale='Blues',\n",
    "        name='Camera Orientations'\n",
    "    )\n",
    "    \n",
    "    # Set up the layout\n",
    "    layout = go.Layout(\n",
    "        title='Camera Poses Visualization',\n",
    "        scene=dict(\n",
    "            xaxis=dict(title='X'),\n",
    "            yaxis=dict(title='Y'),\n",
    "            zaxis=dict(title='Z'),\n",
    "        ),\n",
    "        margin=dict(l=0, r=0, b=0, t=40),\n",
    "        height=800\n",
    "    )\n",
    "    \n",
    "    # Create the figure and show it\n",
    "    fig = go.Figure(data=[scatter, quiver], layout=layout)\n",
    "    fig.show()\n",
    "\n",
    "# Example usage:\n",
    "# NOTE: placeholder path -- point it at the DTU calibration directory that\n",
    "# holds the pos_XXX.txt files read by plot_camera_poses.\n",
    "base_path = '/path/to/dust3r_data/datasets_raw/DTU/SampleSet/MVS Data/Calibration/cal18'\n",
    "pose_count = 49  # Adjust this according to the number of poses available\n",
    "\n",
    "plot_camera_poses(base_path, pose_count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fast3r",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
